mirror of https://github.com/google/gemma.cpp.git
Support directly observing activations, partially replacing LayersOutputFunc
LayersOutputFunc is no longer invoked for "blocks" and "final_norm" outputs. Instead, we directly expose the Activations structure. PiperOrigin-RevId: 663409316
This commit is contained in:
parent
22995c699d
commit
b9ed12a325
|
|
@ -748,13 +748,12 @@ class PrefillState {
|
|||
// Generates one token for each query. `queries_token` is the previous token
|
||||
// from each query, and `queries_pos` are their position in the sequence.
|
||||
template <class TConfig>
|
||||
HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
|
||||
const QueriesMutablePos& queries_pos,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
Activations& activations,
|
||||
const hwy::Divisor& div_seq_len,
|
||||
const KVCaches& kv_caches, hwy::ThreadPool& pool,
|
||||
const LayersOutputFunc& layers_output) {
|
||||
HWY_NOINLINE void Transformer(
|
||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||
hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
|
||||
const ActivationsObserverFunc& activations_observer) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
||||
|
|
@ -778,24 +777,17 @@ HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
|
|||
layer_weights, activations, div_seq_len,
|
||||
kv_caches, pool);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
layers_output(query_idx, queries_pos[query_idx], "blocks", layer,
|
||||
activations.x.Batch(0), kModelDim);
|
||||
}
|
||||
if (activations_observer) {
|
||||
activations_observer(queries_pos, layer, activations);
|
||||
}
|
||||
}
|
||||
|
||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
|
||||
activations.x.All(), kModelDim);
|
||||
|
||||
if (layers_output) {
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
layers_output(query_idx, queries_pos[query_idx], "final_norm", -1,
|
||||
activations.x.Batch(0), kModelDim);
|
||||
if (activations_observer) {
|
||||
activations_observer(queries_pos, -1, activations);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
queries_pos[query_idx] += 1;
|
||||
}
|
||||
|
|
@ -970,7 +962,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
// Decode generates one token per query and increments queries_mutable_pos.
|
||||
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, weights, activations, div_seq_len,
|
||||
kv_caches, pool, runtime_config.layers_output);
|
||||
kv_caches, pool, runtime_config.layers_output,
|
||||
runtime_config.activations_observer);
|
||||
// queries_pos are incremented by Transformer.
|
||||
|
||||
bool all_queries_eos = true;
|
||||
|
|
|
|||
|
|
@ -35,6 +35,14 @@
|
|||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
|
||||
namespace gcpp {
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||
|
|
@ -53,12 +61,18 @@ using SampleFunc = std::function<int(const float*, size_t)>;
|
|||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||
// - index of query within containing batch (if any); zero otherwise.
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, e.g. "tokens", "blocks", "final_norm"
|
||||
// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
|
||||
// - name of the data, e.g. "tokens" for token IDs
|
||||
// - layer index (or -1 for global outputs)
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||
int, const float*, size_t)>;
|
||||
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||
// - per-query position within the tokens sequence
|
||||
// - layer index (or -1 for post-norm output)
|
||||
// - activations
|
||||
using ActivationsObserverFunc =
|
||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
struct RuntimeConfig {
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
|
|
@ -85,6 +99,7 @@ struct RuntimeConfig {
|
|||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||
ActivationsObserverFunc activations_observer; // if set, called per-layer
|
||||
int eos_id = EOS_ID;
|
||||
};
|
||||
|
||||
|
|
@ -141,15 +156,6 @@ struct TimingInfo {
|
|||
size_t tokens_generated = 0;
|
||||
};
|
||||
|
||||
using PromptTokens = hwy::Span<const int>;
|
||||
|
||||
// Batches of independent queries have their own prompt, previous token,
|
||||
// position in the sequence, and KVCache.
|
||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||
using QueriesToken = hwy::Span<const int>;
|
||||
using QueriesPos = hwy::Span<const size_t>;
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
class Gemma {
|
||||
public:
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||
|
|
|
|||
Loading…
Reference in New Issue