diff --git a/evals/debug_prompt.cc b/evals/debug_prompt.cc index 8a4cc8f..7ea32fa 100644 --- a/evals/debug_prompt.cc +++ b/evals/debug_prompt.cc @@ -59,10 +59,13 @@ int Run(int argc, char** argv) { env.MutableConfig().layers_output = prompt_args.layers_output.Empty() ? LayersOutputFunc() - : [&json_output](int pos, const std::string& key, const float* values, - size_t values_len) { - std::vector v{values, values + values_len}; - json_output[std::to_string(pos)][key] = v; + : [&json_output](size_t query_idx, size_t pos, const std::string& key, + int layer, const float* values, size_t values_len) { + const std::string& debug_key = + layer < 0 ? key : (key + "." + std::to_string(layer)); + const std::vector v{values, values + values_len}; + json& json_base = json_output[std::to_string(query_idx)]; + json_base[std::to_string(pos)][debug_key] = v; }; const auto [answer, token_count] = env.QueryModel(prompt_args.prompt); diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 0613b92..9992981 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -805,8 +805,10 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, const size_t num_interleaved = num_tokens * num_queries; if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { - float token_f = tokens[token_idx]; - layers_output(pos + token_idx, "Tokens", &token_f, 1); + const size_t query_idx = token_idx % num_queries; + const size_t logical_pos = (pos + token_idx) / num_queries; + const float token_f = tokens[token_idx]; + layers_output(query_idx, logical_pos, "tokens", -1, &token_f, 1); } } constexpr size_t kModelDim = TConfig::kModelDim; @@ -821,9 +823,9 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, layer_weights, activations, kv_caches, pool); if (layers_output) { - const std::string block_name = "blocks." + std::to_string(layer); for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { - layers_output(pos + token_idx, block_name, + const size_t logical_pos = (pos + token_idx) / num_queries; + layers_output(token_idx % num_queries, logical_pos, "blocks", layer, activations.x.Batch(token_idx), kModelDim); } } @@ -833,7 +835,9 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, activations.x.All(), kModelDim); if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { - layers_output(pos + token_idx, "final_norm", + const size_t query_idx = token_idx % num_queries; + const size_t logical_pos = (pos + token_idx) / num_queries; + layers_output(query_idx, logical_pos, "final_norm", -1, activations.x.Batch(token_idx), kModelDim); } } diff --git a/gemma/gemma.h b/gemma/gemma.h index e734456..081e783 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -49,13 +49,15 @@ using AcceptFunc = std::function; // If not empty, SampleFunc is called with the probability distribution for the // next token, and its return value is used as the next generated token. using SampleFunc = std::function; -// Will be called for layers output with: +// If not empty, LayersOutputFunc is called for layer outputs, specified with: +// - index of query within containing batch (if any); zero otherwise. // - position in the tokens sequence -// - name of the data, p.ex. "tokens", "block.1", "final_norm" +// - name of the data, e.g. "tokens", "blocks", "final_norm" +// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer // - pointer to the data array // - size of the data array using LayersOutputFunc = - std::function; + std::function; struct RuntimeConfig { bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {