1.04x speedup: Parallelize SoftCap

Also require an opt-in constexpr flag for observer callbacks, and update profiler zones.

PiperOrigin-RevId: 799655163
Author: Jan Wassenberg, 2025-08-26 11:54:48 -07:00 (committed by Copybara-Service)
Parent: ed2f0bd1b0
Commit: 86afd53076
2 changed files with 45 additions and 16 deletions
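
For context: the soft-cap named in the title squashes each logit into (-cap, cap). A minimal scalar sketch, assuming the usual Gemma definition cap * tanh(x / cap) (SoftCapRow is a hypothetical name; the repo's LogitsSoftCap is the real, vectorized implementation):

#include <cmath>
#include <cstddef>

// Applies row[i] = cap * tanh(row[i] / cap), squashing logits into
// (-cap, cap) before sampling. One call per query row.
void SoftCapRow(const float cap, float* row, const size_t cols) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < cols; ++i) {
    row[i] = cap * std::tanh(row[i] * inv_cap);
  }
}

Each row holds vocab_size floats, so there is enough work per query to be worth handing rows to separate workers, which is where the speedup in the title comes from.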

File 1 of 2:

@@ -61,6 +61,9 @@
 #include "hwy/base.h"
 #include "hwy/timer.h"
+
+// Require opt-in to debug/introspection functions to eliminate their overhead.
+HWY_INLINE_VAR constexpr bool kObserver = false;
 #endif  // GEMMA_CC_ONCE
 HWY_BEFORE_NAMESPACE();
@@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              MatStorageT<float>& x, ThreadingContext& ctx,
              const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
+  static const auto zone =
+      ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
+  PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
+
   // Image tokens just need to be copied.
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
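
The zone change above replaces the string-keyed PROFILER_ZONE macro with a handle registered once via AddZone and then passed to PROFILER_ZONE3 on every call. A dependency-free sketch of the register-once idiom (this Profiler struct is illustrative, not hwy's API):

#include <cstdio>

struct Profiler {
  // Registers a named zone and returns its handle (sketch only).
  int AddZone(const char* name) {
    std::printf("registered %s\n", name);
    return next_id++;
  }
  int next_id = 0;
};

void EmbedToken(Profiler& profiler) {
  // Function-local static: AddZone runs exactly once, on the first call.
  // Later calls skip registration and pay only for the cheap enter/exit.
  static const int zone = profiler.AddZone("Gen.Embed");
  (void)zone;  // a real profiler would record an enter/exit pair here
}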
@@ -295,11 +302,13 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const WeightsPtrs& weights,
                                      Activations& activations, QBatch& qbatch,
                                      MatMulEnv& env) {
-  if (HWY_UNLIKELY(runtime_config.layers_output)) {
-    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
-      const float token_f = qbatch.PrevToken(qi);
-      runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
-                                   "tokens", -1, &token_f, 1);
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.layers_output)) {
+      for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+        const float token_f = qbatch.PrevToken(qi);
+        runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                     "tokens", -1, &token_f, 1);
+      }
     }
   }
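
This and the following observer hunks share one pattern: the runtime check is nested under if constexpr (kObserver), so with the default kObserver = false the whole branch is discarded at compile time and costs nothing, not even the HWY_UNLIKELY test. A self-contained illustration (Step and Observer are hypothetical):

#include <cstdio>

// Opt-in flag, as in the commit: observer hooks compile away by default.
constexpr bool kObserver = false;

using Observer = void (*)(int token);

void Step(int token, Observer observer) {
  if constexpr (kObserver) {
    // Discarded at compile time: when kObserver is false, neither the
    // null check nor the indirect call is emitted, so the hook is free.
    if (observer != nullptr) observer(token);
  }
  std::printf("processed token %d\n", token);
}

int main() {
  Step(42, [](int t) { std::printf("observed %d\n", t); });
  return 0;
}

Flipping the constant back to true restores the callbacks with no other source changes.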
@@ -313,10 +322,12 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
                      activations, qbatch, env);
-    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(
-          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
-          activations);
+    if constexpr (kObserver) {
+      if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+        runtime_config.activations_observer(
+            QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+            activations);
+      }
     }
   }
 }
@@ -403,23 +414,29 @@ static void SampleAndStream(
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
-  if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-    runtime_config.activations_observer(
-        QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
+    }
   }
   {
-    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    static const auto zone = env.ctx.profiler.AddZone(
+        "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
+    PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
     // Compute logits from last layer activations.
     CallMatMul(activations.x, weights.embedder_input_embedding,
                /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-  const size_t worker = 0;  // TODO: parallelize
+  MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
+                            env.ctx);
+  // TODO: parallelize
   non_eos.Foreach([&](size_t qi) {
     float* HWY_RESTRICT logits = activations.logits.Row(qi);
-    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
-                       env.ctx.profiler, worker);
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();

File 2 of 2:

@@ -33,6 +33,7 @@
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
+#include "hwy/bit_set.h"
 #include "hwy/contrib/sort/order.h"
 #include "hwy/contrib/sort/vqsort.h"
 #include "hwy/detect_targets.h"
@@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
   }
 }
+
+static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
+    const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
+    ThreadingContext& ctx) {
+  if (cap == 0.0f) return;
+  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
+    if (non_eos.Get(task)) {
+      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
+    }
+  });
+}
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
 SampleArgmax(const float* probabilities, size_t vocab_size) {
   size_t max_index = 0;
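
The new helper hoists the soft-cap out of the serial per-query Foreach in SampleAndStream and runs it row-parallel on the context's pools, skipping rows whose query has already reached EOS. SmallParallelFor's scheduling is not shown in this diff; a rough single-file sketch of an equivalent row-parallel loop (ParallelForRows is hypothetical):

#include <cstddef>
#include <thread>
#include <vector>

// Runs fn(row, worker) for every row, striding rows across num_workers
// threads. The stable worker index lets callees keep per-worker state,
// as with the worker argument that LogitsSoftCap passes to the profiler.
template <typename Fn>
void ParallelForRows(size_t num_rows, size_t num_workers, const Fn& fn) {
  std::vector<std::thread> threads;
  threads.reserve(num_workers);
  for (size_t w = 0; w < num_workers; ++w) {
    threads.emplace_back([num_rows, num_workers, w, &fn] {
      for (size_t row = w; row < num_rows; row += num_workers) {
        fn(row, w);
      }
    });
  }
  for (auto& t : threads) t.join();
}

Spawning threads per call, as this sketch does, would likely swamp a 1.04x gain at small batch sizes; reusing the already-running pool in ctx.pools avoids that overhead.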