mirror of https://github.com/google/gemma.cpp.git
Support directly observing activations, partially replacing LayersOutputFunc
LayersOutputFunc is no longer invoked for "blocks" and "final_norm" outputs. Instead, we directly expose the Activations structure. PiperOrigin-RevId: 663409316
This commit is contained in:
parent
22995c699d
commit
b9ed12a325
|
|
@ -748,13 +748,12 @@ class PrefillState {
|
||||||
// Generates one token for each query. `queries_token` is the previous token
|
// Generates one token for each query. `queries_token` is the previous token
|
||||||
// from each query, and `queries_pos` are their position in the sequence.
|
// from each query, and `queries_pos` are their position in the sequence.
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
|
HWY_NOINLINE void Transformer(
|
||||||
const QueriesMutablePos& queries_pos,
|
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||||
const CompressedWeights<TConfig>& weights,
|
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||||
Activations& activations,
|
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
|
||||||
const hwy::Divisor& div_seq_len,
|
hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
|
||||||
const KVCaches& kv_caches, hwy::ThreadPool& pool,
|
const ActivationsObserverFunc& activations_observer) {
|
||||||
const LayersOutputFunc& layers_output) {
|
|
||||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
const size_t num_queries = queries_token.size();
|
const size_t num_queries = queries_token.size();
|
||||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
HWY_DASSERT(queries_pos.size() == num_queries);
|
||||||
|
|
@ -778,24 +777,17 @@ HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
|
||||||
layer_weights, activations, div_seq_len,
|
layer_weights, activations, div_seq_len,
|
||||||
kv_caches, pool);
|
kv_caches, pool);
|
||||||
|
|
||||||
if (layers_output) {
|
if (activations_observer) {
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
activations_observer(queries_pos, layer, activations);
|
||||||
layers_output(query_idx, queries_pos[query_idx], "blocks", layer,
|
|
||||||
activations.x.Batch(0), kModelDim);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
|
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
|
||||||
activations.x.All(), kModelDim);
|
activations.x.All(), kModelDim);
|
||||||
|
|
||||||
if (layers_output) {
|
if (activations_observer) {
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
activations_observer(queries_pos, -1, activations);
|
||||||
layers_output(query_idx, queries_pos[query_idx], "final_norm", -1,
|
|
||||||
activations.x.Batch(0), kModelDim);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
queries_pos[query_idx] += 1;
|
queries_pos[query_idx] += 1;
|
||||||
}
|
}
|
||||||
|
|
@ -970,7 +962,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
// Decode generates one token per query and increments queries_mutable_pos.
|
// Decode generates one token per query and increments queries_mutable_pos.
|
||||||
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
|
||||||
queries_mutable_pos, weights, activations, div_seq_len,
|
queries_mutable_pos, weights, activations, div_seq_len,
|
||||||
kv_caches, pool, runtime_config.layers_output);
|
kv_caches, pool, runtime_config.layers_output,
|
||||||
|
runtime_config.activations_observer);
|
||||||
// queries_pos are incremented by Transformer.
|
// queries_pos are incremented by Transformer.
|
||||||
|
|
||||||
bool all_queries_eos = true;
|
bool all_queries_eos = true;
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,14 @@
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
using PromptTokens = hwy::Span<const int>;
|
||||||
|
|
||||||
|
// Batches of independent queries have their own prompt, previous token,
|
||||||
|
// position in the sequence, and KVCache.
|
||||||
|
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
||||||
|
using QueriesToken = hwy::Span<const int>;
|
||||||
|
using QueriesPos = hwy::Span<const size_t>;
|
||||||
|
using KVCaches = hwy::Span<KVCache>;
|
||||||
|
|
||||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||||
// probability is 0.0f. StreamFunc should return false to stop generation and
|
// probability is 0.0f. StreamFunc should return false to stop generation and
|
||||||
|
|
@ -53,12 +61,18 @@ using SampleFunc = std::function<int(const float*, size_t)>;
|
||||||
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
// If not empty, LayersOutputFunc is called for layer outputs, specified with:
|
||||||
// - index of query within containing batch (if any); zero otherwise.
|
// - index of query within containing batch (if any); zero otherwise.
|
||||||
// - position in the tokens sequence
|
// - position in the tokens sequence
|
||||||
// - name of the data, e.g. "tokens", "blocks", "final_norm"
|
// - name of the data, e.g. "tokens" for token IDs
|
||||||
// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
|
// - layer index (or -1 for global outputs)
|
||||||
// - pointer to the data array
|
// - pointer to the data array
|
||||||
// - size of the data array
|
// - size of the data array
|
||||||
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
|
||||||
int, const float*, size_t)>;
|
int, const float*, size_t)>;
|
||||||
|
// If not empty, ActivationsObserverFunc is invoked after each layer with:
|
||||||
|
// - per-query position within the tokens sequence
|
||||||
|
// - layer index (or -1 for post-norm output)
|
||||||
|
// - activations
|
||||||
|
using ActivationsObserverFunc =
|
||||||
|
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||||
|
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||||
|
|
@ -85,6 +99,7 @@ struct RuntimeConfig {
|
||||||
AcceptFunc accept_token; // if empty, accepts all tokens.
|
AcceptFunc accept_token; // if empty, accepts all tokens.
|
||||||
SampleFunc sample_func; // if empty, uses SampleTopK.
|
SampleFunc sample_func; // if empty, uses SampleTopK.
|
||||||
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
LayersOutputFunc layers_output; // if not empty, called after each layer.
|
||||||
|
ActivationsObserverFunc activations_observer; // if set, called per-layer
|
||||||
int eos_id = EOS_ID;
|
int eos_id = EOS_ID;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -141,15 +156,6 @@ struct TimingInfo {
|
||||||
size_t tokens_generated = 0;
|
size_t tokens_generated = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
using PromptTokens = hwy::Span<const int>;
|
|
||||||
|
|
||||||
// Batches of independent queries have their own prompt, previous token,
|
|
||||||
// position in the sequence, and KVCache.
|
|
||||||
using QueriesPromptTokens = hwy::Span<const PromptTokens>;
|
|
||||||
using QueriesToken = hwy::Span<const int>;
|
|
||||||
using QueriesPos = hwy::Span<const size_t>;
|
|
||||||
using KVCaches = hwy::Span<KVCache>;
|
|
||||||
|
|
||||||
class Gemma {
|
class Gemma {
|
||||||
public:
|
public:
|
||||||
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue