diff --git a/gemma/activations.h b/gemma/activations.h
index 877afdf..175ddc9 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -155,6 +155,7 @@ struct Activations {
       : layer_config(config.layer_configs[0]),
         x(MatFactory("x", batch_size, config.model_dim, allocator)),
+        x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
         logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
         pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
@@ -173,6 +174,7 @@ struct Activations {
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     x.AllocateAndAttachRowPtrs(row_ptrs);
+    x_bf.AllocateAndAttachRowPtrs(row_ptrs);
     logits.AllocateAndAttachRowPtrs(row_ptrs);
     C1.AllocateAndAttachRowPtrs(row_ptrs);
     C2.AllocateAndAttachRowPtrs(row_ptrs);
@@ -184,6 +186,7 @@ struct Activations {
   // Negligible CPU time.
   void SetBatchSize(size_t batch_size) {
     x.OverrideRows(batch_size);
+    x_bf.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);
     pre_ffw_rms_out.OverrideRows(batch_size);
@@ -198,6 +201,7 @@ struct Activations {
   const LayerConfig& layer_config;

   MatStorageT x;  // input
+  MatStorageT x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
   MatStorageT logits;

   // Gated FFW
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index c6da86c..fc1f238 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
   }
 }

+static void MaybeObserve(const RuntimeConfig& runtime_config,
+                         Activations& activations, QBatch& qbatch,
+                         int layer_idx) {
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+          activations);
+    }
+  }
+}
+
 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
 // Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
 // token-batched `PrefillTBatch`, which supports image embedding.
@@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
                      activations, qbatch, env);

-    if constexpr (kObserver) {
-      if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-        runtime_config.activations_observer(
-            QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
-            activations);
-      }
-    }
+    MaybeObserve(runtime_config, activations, qbatch, layer_idx);
   }
 }
@@ -412,21 +418,17 @@ static void SampleAndStream(
     hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());

-  RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
+  RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
+                 env.ctx);

-  if constexpr (kObserver) {
-    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(
-          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
-    }
-  }
+  MaybeObserve(runtime_config, activations, qbatch, -1);

   {
     static const auto zone = env.ctx.profiler.AddZone(
         "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
     PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
     // Compute logits from last layer activations.
-    CallMatMul(activations.x, weights.embedder_input_embedding,
+    CallMatMul(activations.x_bf, weights.embedder_input_embedding,
                /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 469ba2a..16c9595 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase {
                 "resets every turn)");
     visitor(image_file, "image_file", Path(), "Image file to load.");

-    // Since it is not used in the CLI version, the print_verbosity is set higher than others.
+    // Since it is not used in the CLI version, the print_verbosity is set
+    // higher than others.
     visitor(port, "port", 8080, "Server port (default: 8080)", 3);
-    visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3);
+    visitor(model, "model", std::string("gemma3-4b"),
+            "Model name for API endpoints (default: gemma3-4b)", 3);

     visitor(prompt, "prompt", std::string(""),
             "Initial prompt for non-interactive mode. When specified, "
@@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase {
     visitor(port, "port", 8080, "Server port (default: 8080)");
     visitor(api_key, "api_key", std::string(""),
-            "Use public API with key (changes host to generativelanguage.googleapis.com:443)");
+            "Use public API with key (changes host to "
+            "generativelanguage.googleapis.com:443)");
     visitor(model, "model", std::string("gemma3-4b"),
             "Model name to use (default: gemma3-4b)");
     visitor(prompt, "prompt", std::string("Hello! How are you?"),