1.04x speedup: Parallelize SoftCap

Also require opt-in constexpr flag for observer callbacks, update zones

PiperOrigin-RevId: 799655163
This commit is contained in:
Jan Wassenberg 2025-08-26 11:54:48 -07:00 committed by Copybara-Service
parent ed2f0bd1b0
commit 86afd53076
2 changed files with 45 additions and 16 deletions

View File

@ -61,6 +61,9 @@
#include "hwy/base.h"
#include "hwy/timer.h"
// Require opt-in to debug/introspection functions to eliminate their overhead.
HWY_INLINE_VAR constexpr bool kObserver = false;
#endif // GEMMA_CC_ONCE
HWY_BEFORE_NAMESPACE();
@ -143,6 +146,10 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
MatStorageT<float>& x, ThreadingContext& ctx,
const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
static const auto zone =
ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
// Image tokens just need to be copied.
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
@ -295,6 +302,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const float token_f = qbatch.PrevToken(qi);
@ -302,6 +310,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
"tokens", -1, &token_f, 1);
}
}
}
// TODO: parallelize?
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
@ -313,12 +322,14 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env);
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
}
// Populates KV cache for the batch queries, one token at a time.
@ -403,23 +414,29 @@ static void SampleAndStream(
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
}
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
static const auto zone = env.ctx.profiler.AddZone(
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
// Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
const size_t worker = 0; // TODO: parallelize
MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos,
env.ctx);
// TODO: parallelize
non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi);
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size,
env.ctx.profiler, worker);
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();

View File

@ -33,6 +33,7 @@
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/bit_set.h"
#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/detect_targets.h"
@ -932,6 +933,17 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
}
}
// Applies logits soft-capping to every non-EOS row of `x`, in parallel.
// A cap of 0.0f means "disabled" and the call is a no-op. Rows whose query
// has already reached end-of-sequence (bit cleared in `non_eos`) are skipped.
static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
    const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
    ThreadingContext& ctx) {
  // Soft-capping disabled: nothing to do.
  if (cap == 0.0f) return;
  // One task per row; `worker` identifies the pool thread for profiling.
  const auto cap_one_row = [&](uint64_t row, size_t worker) {
    if (!non_eos.Get(row)) return;  // query already at EOS — leave row as-is.
    LogitsSoftCap(cap, x.Row(row), x.Cols(), ctx.profiler, worker);
  };
  SmallParallelFor(x.Rows(), ctx.pools, cap_one_row);
}
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
SampleArgmax(const float* probabilities, size_t vocab_size) {
size_t max_index = 0;