1.01x speedup: More bf16 activations to reduce DecompressA.

Also move observer call into function, format gemma_args.

PiperOrigin-RevId: 800827400
This commit is contained in:
Jan Wassenberg 2025-08-29 03:18:28 -07:00 committed by Copybara-Service
parent 7288891439
commit 6c39a2dea4
3 changed files with 27 additions and 18 deletions

View File

@ -155,6 +155,7 @@ struct Activations {
: layer_config(config.layer_configs[0]),
x(MatFactory("x", batch_size, config.model_dim, allocator)),
x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
@ -173,6 +174,7 @@ struct Activations {
// If we forget any MatMul outputs here, debug builds print a warning but
// fill them in each MatMul call.
x.AllocateAndAttachRowPtrs(row_ptrs);
x_bf.AllocateAndAttachRowPtrs(row_ptrs);
logits.AllocateAndAttachRowPtrs(row_ptrs);
C1.AllocateAndAttachRowPtrs(row_ptrs);
C2.AllocateAndAttachRowPtrs(row_ptrs);
@ -184,6 +186,7 @@ struct Activations {
// Negligible CPU time.
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
x_bf.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
@ -198,6 +201,7 @@ struct Activations {
const LayerConfig& layer_config;
MatStorageT<float> x; // input
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
MatStorageT<float> logits;
// Gated FFW

View File

@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
}
}
static void MaybeObserve(const RuntimeConfig& runtime_config,
Activations& activations, QBatch& qbatch,
int layer_idx) {
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
// token-batched `PrefillTBatch`, which supports image embedding.
@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env);
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
MaybeObserve(runtime_config, activations, qbatch, layer_idx);
}
}
@ -412,21 +418,17 @@ static void SampleAndStream(
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
env.ctx);
if constexpr (kObserver) {
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
}
MaybeObserve(runtime_config, activations, qbatch, -1);
{
static const auto zone = env.ctx.profiler.AddZone(
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
// Compute logits from last layer activations.
CallMatMul(activations.x, weights.embedder_input_embedding,
CallMatMul(activations.x_bf, weights.embedder_input_embedding,
/*add=*/nullptr, env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");

View File

@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
// Since it is not used in the CLI version, the print_verbosity is set higher than others.
// Since it is not used in the CLI version, the print_verbosity is set
// higher than others.
visitor(port, "port", 8080, "Server port (default: 8080)", 3);
visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3);
visitor(model, "model", std::string("gemma3-4b"),
"Model name for API endpoints (default: gemma3-4b)", 3);
visitor(prompt, "prompt", std::string(""),
"Initial prompt for non-interactive mode. When specified, "
@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase<ClientArgs> {
visitor(port, "port", 8080,
"Server port (default: 8080)");
visitor(api_key, "api_key", std::string(""),
"Use public API with key (changes host to generativelanguage.googleapis.com:443)");
"Use public API with key (changes host to "
"generativelanguage.googleapis.com:443)");
visitor(model, "model", std::string("gemma3-4b"),
"Model name to use (default: gemma3-4b)");
visitor(prompt, "prompt", std::string("Hello! How are you?"),