mirror of https://github.com/google/gemma.cpp.git
1.01x speedup: More bf16 activations to reduce DecompressA.
Also move observer call into function, format gemma_args. PiperOrigin-RevId: 800827400
This commit is contained in:
parent
7288891439
commit
6c39a2dea4
|
|
@ -155,6 +155,7 @@ struct Activations {
|
|||
: layer_config(config.layer_configs[0]),
|
||||
|
||||
x(MatFactory("x", batch_size, config.model_dim, allocator)),
|
||||
x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
|
||||
logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
|
||||
|
||||
pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
|
||||
|
|
@ -173,6 +174,7 @@ struct Activations {
|
|||
// If we forget any MatMul outputs here, debug builds print a warning but
|
||||
// fill them in each MatMul call.
|
||||
x.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
x_bf.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
logits.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C1.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
C2.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
|
|
@ -184,6 +186,7 @@ struct Activations {
|
|||
// Negligible CPU time.
|
||||
void SetBatchSize(size_t batch_size) {
|
||||
x.OverrideRows(batch_size);
|
||||
x_bf.OverrideRows(batch_size);
|
||||
logits.OverrideRows(batch_size);
|
||||
|
||||
pre_ffw_rms_out.OverrideRows(batch_size);
|
||||
|
|
@ -198,6 +201,7 @@ struct Activations {
|
|||
const LayerConfig& layer_config;
|
||||
|
||||
MatStorageT<float> x; // input
|
||||
MatStorageT<BF16> x_bf; // output of final RMSNorm, input to EmbeddingMatmul
|
||||
MatStorageT<float> logits;
|
||||
|
||||
// Gated FFW
|
||||
|
|
|
|||
|
|
@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
|||
}
|
||||
}
|
||||
|
||||
static void MaybeObserve(const RuntimeConfig& runtime_config,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
int layer_idx) {
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||
activations);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
|
||||
// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
|
||||
// token-batched `PrefillTBatch`, which supports image embedding.
|
||||
|
|
@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch, env);
|
||||
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||
activations);
|
||||
}
|
||||
}
|
||||
MaybeObserve(runtime_config, activations, qbatch, layer_idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -412,21 +418,17 @@ static void SampleAndStream(
|
|||
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
||||
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
|
||||
RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
|
||||
env.ctx);
|
||||
|
||||
if constexpr (kObserver) {
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||
}
|
||||
}
|
||||
MaybeObserve(runtime_config, activations, qbatch, -1);
|
||||
|
||||
{
|
||||
static const auto zone = env.ctx.profiler.AddZone(
|
||||
"Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
|
||||
PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
|
||||
// Compute logits from last layer activations.
|
||||
CallMatMul(activations.x, weights.embedder_input_embedding,
|
||||
CallMatMul(activations.x_bf, weights.embedder_input_embedding,
|
||||
/*add=*/nullptr, env, activations.logits);
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
|
|
|
|||
|
|
@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"resets every turn)");
|
||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||
|
||||
// Since it is not used in the CLI version, the print_verbosity is set higher than others.
|
||||
// Since it is not used in the CLI version, the print_verbosity is set
|
||||
// higher than others.
|
||||
visitor(port, "port", 8080, "Server port (default: 8080)", 3);
|
||||
visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3);
|
||||
visitor(model, "model", std::string("gemma3-4b"),
|
||||
"Model name for API endpoints (default: gemma3-4b)", 3);
|
||||
|
||||
visitor(prompt, "prompt", std::string(""),
|
||||
"Initial prompt for non-interactive mode. When specified, "
|
||||
|
|
@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase<ClientArgs> {
|
|||
visitor(port, "port", 8080,
|
||||
"Server port (default: 8080)");
|
||||
visitor(api_key, "api_key", std::string(""),
|
||||
"Use public API with key (changes host to generativelanguage.googleapis.com:443)");
|
||||
"Use public API with key (changes host to "
|
||||
"generativelanguage.googleapis.com:443)");
|
||||
visitor(model, "model", std::string("gemma3-4b"),
|
||||
"Model name to use (default: gemma3-4b)");
|
||||
visitor(prompt, "prompt", std::string("Hello! How are you?"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue