1.01x speedup: More bf16 activations to reduce DecompressA.

Also move observer call into function, format gemma_args.

PiperOrigin-RevId: 800827400
This commit is contained in:
Jan Wassenberg 2025-08-29 03:18:28 -07:00 committed by Copybara-Service
parent 7288891439
commit 6c39a2dea4
3 changed files with 27 additions and 18 deletions

View File

@ -155,6 +155,7 @@ struct Activations {
: layer_config(config.layer_configs[0]), : layer_config(config.layer_configs[0]),
x(MatFactory("x", batch_size, config.model_dim, allocator)), x(MatFactory("x", batch_size, config.model_dim, allocator)),
x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)), logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size, pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
@ -173,6 +174,7 @@ struct Activations {
// If we forget any MatMul outputs here, debug builds print a warning but // If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call. // fill them in each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs); x.AllocateAndAttachRowPtrs(row_ptrs);
x_bf.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs); logits.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs); C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs); C2.AllocateAndAttachRowPtrs(row_ptrs);
@ -184,6 +186,7 @@ struct Activations {
// Negligible CPU time. // Negligible CPU time.
void SetBatchSize(size_t batch_size) { void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size); x.OverrideRows(batch_size);
x_bf.OverrideRows(batch_size);
logits.OverrideRows(batch_size); logits.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size); pre_ffw_rms_out.OverrideRows(batch_size);
@ -198,6 +201,7 @@ struct Activations {
const LayerConfig& layer_config; const LayerConfig& layer_config;
MatStorageT<float> x; // input MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits; MatStorageT<float> logits;
// Gated FFW // Gated FFW

View File

@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
} }
} }
static void MaybeObserve(const RuntimeConfig& runtime_config,
Activations& activations, QBatch& qbatch,
int layer_idx) {
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
// Embeds PrevToken (one from each query) and calls each TransformerLayer. // Embeds PrevToken (one from each query) and calls each TransformerLayer.
// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the // Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
// token-batched `PrefillTBatch`, which supports image embedding. // token-batched `PrefillTBatch`, which supports image embedding.
@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx), TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env); activations, qbatch, env);
if constexpr (kObserver) { MaybeObserve(runtime_config, activations, qbatch, layer_idx);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
} }
} }
@ -412,21 +418,17 @@ static void SampleAndStream(
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows()); HWY_DASSERT(qbatch.Size() == activations.x.Rows());
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx); RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
env.ctx);
if constexpr (kObserver) { MaybeObserve(runtime_config, activations, qbatch, -1);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
}
{ {
static const auto zone = env.ctx.profiler.AddZone( static const auto zone = env.ctx.profiler.AddZone(
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive); "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone); PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
// Compute logits from last layer activations. // Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding, CallMatMul(activations.x_bf, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits); /*add=*/nullptr, env, activations.logits);
} }
PROFILER_ZONE("Gen.Softcap+Sample+Stream"); PROFILER_ZONE("Gen.Softcap+Sample+Stream");

View File

@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"resets every turn)"); "resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load."); visitor(image_file, "image_file", Path(), "Image file to load.");
// Since it is not used in the CLI version, the print_verbosity is set higher than others. // Since it is not used in the CLI version, the print_verbosity is set
// higher than others.
visitor(port, "port", 8080, "Server port (default: 8080)", 3); visitor(port, "port", 8080, "Server port (default: 8080)", 3);
visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3); visitor(model, "model", std::string("gemma3-4b"),
"Model name for API endpoints (default: gemma3-4b)", 3);
visitor(prompt, "prompt", std::string(""), visitor(prompt, "prompt", std::string(""),
"Initial prompt for non-interactive mode. When specified, " "Initial prompt for non-interactive mode. When specified, "
@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase<ClientArgs> {
visitor(port, "port", 8080, visitor(port, "port", 8080,
"Server port (default: 8080)"); "Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""), visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to generativelanguage.googleapis.com:443)"); "Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"), visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)"); "Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"), visitor(prompt, "prompt", std::string("Hello! How are you?"),