diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index a7d7357..101adab 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -748,13 +748,12 @@ class PrefillState {
 // Generates one token for each query. `queries_token` is the previous token
 // from each query, and `queries_pos` are their position in the sequence.
 template <class TConfig>
-HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
-                              const QueriesMutablePos& queries_pos,
-                              const CompressedWeights<TConfig>& weights,
-                              Activations& activations,
-                              const hwy::Divisor& div_seq_len,
-                              const KVCaches& kv_caches, hwy::ThreadPool& pool,
-                              const LayersOutputFunc& layers_output) {
+HWY_NOINLINE void Transformer(
+    const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
+    const CompressedWeights<TConfig>& weights, Activations& activations,
+    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+    hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
+    const ActivationsObserverFunc& activations_observer) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   const size_t num_queries = queries_token.size();
   HWY_DASSERT(queries_pos.size() == num_queries);
@@ -778,24 +777,17 @@ HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
                               layer_weights, activations, div_seq_len, kv_caches, pool);
 
-    if (layers_output) {
-      for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-        layers_output(query_idx, queries_pos[query_idx], "blocks", layer,
-                      activations.x.Batch(0), kModelDim);
-      }
+    if (activations_observer) {
+      activations_observer(queries_pos, layer, activations);
     }
   }
 
   RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
 
-  if (layers_output) {
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      layers_output(query_idx, queries_pos[query_idx], "final_norm", -1,
-                    activations.x.Batch(0), kModelDim);
-    }
+  if (activations_observer) {
+    activations_observer(queries_pos, -1, activations);
   }
-
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     queries_pos[query_idx] += 1;
   }
@@ -970,7 +962,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
                          queries_mutable_pos, weights, activations, div_seq_len,
-                         kv_caches, pool, runtime_config.layers_output);
+                         kv_caches, pool, runtime_config.layers_output,
+                         runtime_config.activations_observer);
 
     // queries_pos are incremented by Transformer.
     bool all_queries_eos = true;
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ba4dce3..70f3280 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -35,6 +35,14 @@
 #include "hwy/base.h"  // hwy::bfloat16_t
 
 namespace gcpp {
+using PromptTokens = hwy::Span<const int>;
+
+// Batches of independent queries have their own prompt, previous token,
+// position in the sequence, and KVCache.
+using QueriesPromptTokens = hwy::Span<const PromptTokens>;
+using QueriesToken = hwy::Span<const int>;
+using QueriesPos = hwy::Span<const size_t>;
+using KVCaches = hwy::Span<KVCache>;
 
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f. StreamFunc should return false to stop generation and
@@ -53,12 +61,18 @@ using SampleFunc = std::function<int(const float*, size_t)>;
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
-// - name of the data, e.g. "tokens", "blocks", "final_norm"
-// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
+// - name of the data, e.g. "tokens" for token IDs
+// - layer index (or -1 for global outputs)
 // - pointer to the data array
 // - size of the data array
 using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
                                             int, const float*, size_t)>;
+// If not empty, ActivationsObserverFunc is invoked after each layer with:
+// - per-query position within the tokens sequence
+// - layer index (or -1 for post-norm output)
+// - activations
+using ActivationsObserverFunc =
+    std::function<void(const QueriesPos&, int, const Activations&)>;
 
 struct RuntimeConfig {
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
@@ -85,6 +99,7 @@ struct RuntimeConfig {
   AcceptFunc accept_token;  // if empty, accepts all tokens.
   SampleFunc sample_func;   // if empty, uses SampleTopK.
   LayersOutputFunc layers_output;  // if not empty, called after each layer.
+  ActivationsObserverFunc activations_observer;  // if set, called per-layer
   int eos_id = EOS_ID;
 };
 
@@ -141,15 +156,6 @@ struct TimingInfo {
   size_t tokens_generated = 0;
 };
 
-using PromptTokens = hwy::Span<const int>;
-
-// Batches of independent queries have their own prompt, previous token,
-// position in the sequence, and KVCache.
-using QueriesPromptTokens = hwy::Span<const PromptTokens>;
-using QueriesToken = hwy::Span<const int>;
-using QueriesPos = hwy::Span<const size_t>;
-using KVCaches = hwy::Span<KVCache>;
-
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,