mirror of https://github.com/google/gemma.cpp.git
1.01x speedup: More bf16 activations to reduce DecompressA.
Also move observer call into function, format gemma_args.

PiperOrigin-RevId: 800827400
parent 7288891439
commit 6c39a2dea4
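
As the title suggests, the win comes from MatMul's DecompressA step, which prepares the A (activation) operand before the multiply. Having the final RMSNorm write its output directly as bf16 (the new x_bf buffer below) lets the embedding matmul stream half the bytes and avoid a per-call float conversion. A minimal scalar sketch of the idea, with illustrative helper names rather than gemma.cpp's actual Highway-based kernels:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Round-to-nearest-even narrowing of float to bf16 (keep the upper 16 bits
    // of the IEEE-754 representation). NaN handling omitted for brevity.
    static uint16_t BF16FromF32(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
      return static_cast<uint16_t>((bits + rounding) >> 16);
    }

    // Widen bf16 back to float by restoring the low 16 mantissa bits as zero.
    static float F32FromBF16(uint16_t b) {
      const uint32_t bits = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      return f;
    }

    // If the A operand is already bf16, the inner loop widens each element once
    // while reading half the bytes, instead of first materializing a float copy
    // of the row on every MatMul call.
    static float DotBF16(const uint16_t* a_bf16, const float* b, size_t n) {
      float sum = 0.0f;
      for (size_t i = 0; i < n; ++i) sum += F32FromBF16(a_bf16[i]) * b[i];
      return sum;
    }
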
@@ -155,6 +155,7 @@ struct Activations {
       : layer_config(config.layer_configs[0]),
         x(MatFactory("x", batch_size, config.model_dim, allocator)),
+        x_bf(MatFactory("x_bf", batch_size, config.model_dim, allocator)),
         logits(MatFactory("logits", batch_size, config.vocab_size, allocator)),
         pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
@@ -173,6 +174,7 @@ struct Activations {
     // If we forget any MatMul outputs here, debug builds print a warning but
     // fill them in each MatMul call.
     x.AllocateAndAttachRowPtrs(row_ptrs);
+    x_bf.AllocateAndAttachRowPtrs(row_ptrs);
     logits.AllocateAndAttachRowPtrs(row_ptrs);
     C1.AllocateAndAttachRowPtrs(row_ptrs);
     C2.AllocateAndAttachRowPtrs(row_ptrs);
@@ -184,6 +186,7 @@ struct Activations {
   // Negligible CPU time.
   void SetBatchSize(size_t batch_size) {
     x.OverrideRows(batch_size);
+    x_bf.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);
     pre_ffw_rms_out.OverrideRows(batch_size);
@@ -198,6 +201,7 @@ struct Activations {
   const LayerConfig& layer_config;
 
   MatStorageT<float> x;  // input
+  MatStorageT<BF16> x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
   MatStorageT<float> logits;
 
   // Gated FFW
@@ -294,6 +294,18 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
   }
 }
 
+static void MaybeObserve(const RuntimeConfig& runtime_config,
+                         Activations& activations, QBatch& qbatch,
+                         int layer_idx) {
+  if constexpr (kObserver) {
+    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+          activations);
+    }
+  }
+}
+
 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
 // Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
 // token-batched `PrefillTBatch`, which supports image embedding.
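
The new MaybeObserve helper centralizes the kObserver guard and the null check that were previously repeated at each call site; callers pass the layer index, or -1 for the call after the final RMSNorm. A hedged sketch of how a caller might install the hook; the exact member type of RuntimeConfig::activations_observer is not shown in this diff, so the lambda below deduces its parameters rather than asserting them:

    // Assumed: activations_observer behaves like a std::function (it is
    // truth-tested, then invoked with queries' positions, a layer index, and
    // the Activations), as seen at the call site above.
    runtime_config.activations_observer =
        [](const auto& queries_pos, int layer_idx, const auto& activations) {
          // layer_idx is the transformer layer, or -1 for the post-final-norm
          // call in SampleAndStream.
          (void)queries_pos;
          (void)activations;
          if (layer_idx == -1) {
            // e.g. inspect pre-logits activations here.
          }
        };
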
@@ -322,13 +334,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
     TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
                      activations, qbatch, env);
 
-    if constexpr (kObserver) {
-      if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-        runtime_config.activations_observer(
-            QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
-            activations);
-      }
-    }
+    MaybeObserve(runtime_config, activations, qbatch, layer_idx);
   }
 }
@@ -412,21 +418,17 @@ static void SampleAndStream(
     hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());
 
-  RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);
+  RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf,
+                 env.ctx);
 
-  if constexpr (kObserver) {
-    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(
-          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
-    }
-  }
+  MaybeObserve(runtime_config, activations, qbatch, -1);
 
   {
     static const auto zone = env.ctx.profiler.AddZone(
         "Gen.EmbeddingMatmul", hwy::ProfilerFlags::kInclusive);
     PROFILER_ZONE3(env.ctx.profiler, /*worker=*/0, zone);
     // Compute logits from last layer activations.
-    CallMatMul(activations.x, weights.embedder_input_embedding,
+    CallMatMul(activations.x_bf, weights.embedder_input_embedding,
                /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
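
The data flow in SampleAndStream thus becomes: float x rows -> RMSNormBatched -> bf16 x_bf rows -> CallMatMul against the embedding table -> float logits. For reference, a scalar sketch of an RMSNorm that narrows to bf16 on output, using the textbook form with an assumed epsilon; gemma.cpp's RMSNormBatched is the batched, vectorized equivalent and may differ in details such as the weight offset:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static uint16_t TruncateToBF16(float f) {  // truncation; rounding omitted
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      return static_cast<uint16_t>(bits >> 16);
    }

    // out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps), narrowed to bf16.
    static void RMSNormToBF16(const float* x, const float* w, uint16_t* out,
                              size_t n, float eps = 1e-6f) {
      float ss = 0.0f;
      for (size_t i = 0; i < n; ++i) ss += x[i] * x[i];
      const float inv_rms = 1.0f / std::sqrt(ss / static_cast<float>(n) + eps);
      for (size_t i = 0; i < n; ++i) {
        out[i] = TruncateToBF16(x[i] * w[i] * inv_rms);
      }
    }
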
@@ -220,9 +220,11 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "resets every turn)");
     visitor(image_file, "image_file", Path(), "Image file to load.");
 
-    // Since it is not used in the CLI version, the print_verbosity is set higher than others.
+    // Since it is not used in the CLI version, the print_verbosity is set
+    // higher than others.
     visitor(port, "port", 8080, "Server port (default: 8080)", 3);
-    visitor(model, "model", std::string("gemma3-4b"), "Model name for API endpoints (default: gemma3-4b)", 3);
+    visitor(model, "model", std::string("gemma3-4b"),
+            "Model name for API endpoints (default: gemma3-4b)", 3);
 
     visitor(prompt, "prompt", std::string(""),
             "Initial prompt for non-interactive mode. When specified, "
@@ -282,7 +284,8 @@ struct ClientArgs : public ArgsBase<ClientArgs> {
     visitor(port, "port", 8080,
             "Server port (default: 8080)");
     visitor(api_key, "api_key", std::string(""),
-            "Use public API with key (changes host to generativelanguage.googleapis.com:443)");
+            "Use public API with key (changes host to "
+            "generativelanguage.googleapis.com:443)");
     visitor(model, "model", std::string("gemma3-4b"),
             "Model name to use (default: gemma3-4b)");
     visitor(prompt, "prompt", std::string("Hello! How are you?"),