Support directly observing activations, partially replacing LayersOutputFunc

LayersOutputFunc is no longer invoked for "blocks" and "final_norm" outputs.
Instead, we directly expose the Activations structure.

PiperOrigin-RevId: 663409316
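For orientation, here is a minimal sketch of how a caller might register the new observer. The RuntimeConfig field, the ActivationsObserverFunc signature, and the layer == -1 convention are taken from this change; the include path, function name, and the rest of the snippet are illustrative assumptions, not code from this commit.

// Illustrative sketch only, not part of this commit; the include path is assumed.
#include "gemma/gemma.h"  // gcpp::RuntimeConfig, ActivationsObserverFunc

void ConfigureObserver(gcpp::RuntimeConfig& runtime_config) {
  // Invoked after every transformer block, and once more after final_norm.
  runtime_config.activations_observer =
      [](const gcpp::QueriesPos& queries_pos, int layer,
         const gcpp::Activations& activations) {
        // layer >= 0: activations after transformer block `layer`.
        // layer == -1: activations after the final RMSNorm ("final_norm").
        // queries_pos[i] is the current sequence position of query i.
        (void)queries_pos;
        (void)layer;
        (void)activations;  // Inspect or copy data from `activations` here.
      };
}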
Author: Paul Chang, 2024-08-15 12:38:37 -07:00 (committed by Copybara-Service)
parent 22995c699d
commit b9ed12a325
2 changed files with 29 additions and 30 deletions

View File

@@ -748,13 +748,12 @@ class PrefillState {
 // Generates one token for each query. `queries_token` is the previous token
 // from each query, and `queries_pos` are their position in the sequence.
 template <class TConfig>
-HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
-                              const QueriesMutablePos& queries_pos,
-                              const CompressedWeights<TConfig>& weights,
-                              Activations& activations,
-                              const hwy::Divisor& div_seq_len,
-                              const KVCaches& kv_caches, hwy::ThreadPool& pool,
-                              const LayersOutputFunc& layers_output) {
+HWY_NOINLINE void Transformer(
+    const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
+    const CompressedWeights<TConfig>& weights, Activations& activations,
+    const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
+    hwy::ThreadPool& pool, const LayersOutputFunc& layers_output,
+    const ActivationsObserverFunc& activations_observer) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   const size_t num_queries = queries_token.size();
   HWY_DASSERT(queries_pos.size() == num_queries);
@@ -778,24 +777,17 @@ HWY_NOINLINE void Transformer(const QueriesToken& queries_token,
                                layer_weights, activations, div_seq_len,
                                kv_caches, pool);
-    if (layers_output) {
-      for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-        layers_output(query_idx, queries_pos[query_idx], "blocks", layer,
-                      activations.x.Batch(0), kModelDim);
-      }
+    if (activations_observer) {
+      activations_observer(queries_pos, layer, activations);
     }
   }
   RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
-  if (layers_output) {
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      layers_output(query_idx, queries_pos[query_idx], "final_norm", -1,
-                    activations.x.Batch(0), kModelDim);
-    }
-  }
+  if (activations_observer) {
+    activations_observer(queries_pos, -1, activations);
+  }
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     queries_pos[query_idx] += 1;
   }
@@ -970,7 +962,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     // Decode generates one token per query and increments queries_mutable_pos.
     Transformer<TConfig>(QueriesToken(gen_tokens.data(), num_queries),
                          queries_mutable_pos, weights, activations, div_seq_len,
-                         kv_caches, pool, runtime_config.layers_output);
+                         kv_caches, pool, runtime_config.layers_output,
+                         runtime_config.activations_observer);
     // queries_pos are incremented by Transformer.
     bool all_queries_eos = true;

View File

@@ -35,6 +35,14 @@
 #include "hwy/base.h"  // hwy::bfloat16_t
 namespace gcpp {
+using PromptTokens = hwy::Span<const int>;
+// Batches of independent queries have their own prompt, previous token,
+// position in the sequence, and KVCache.
+using QueriesPromptTokens = hwy::Span<const PromptTokens>;
+using QueriesToken = hwy::Span<const int>;
+using QueriesPos = hwy::Span<const size_t>;
+using KVCaches = hwy::Span<KVCache>;
 // StreamFunc is called with (token, probability). For prompt tokens,
 // probability is 0.0f. StreamFunc should return false to stop generation and
@@ -53,12 +61,18 @@ using SampleFunc = std::function<int(const float*, size_t)>;
 // If not empty, LayersOutputFunc is called for layer outputs, specified with:
 // - index of query within containing batch (if any); zero otherwise.
 // - position in the tokens sequence
-// - name of the data, e.g. "tokens", "blocks", "final_norm"
-// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
+// - name of the data, e.g. "tokens" for token IDs
+// - layer index (or -1 for global outputs)
 // - pointer to the data array
 // - size of the data array
 using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
                                             int, const float*, size_t)>;
+// If not empty, ActivationsObserverFunc is invoked after each layer with:
+// - per-query position within the tokens sequence
+// - layer index (or -1 for post-norm output)
+// - activations
+using ActivationsObserverFunc =
+    std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
 struct RuntimeConfig {
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
@@ -85,6 +99,7 @@ struct RuntimeConfig {
   AcceptFunc accept_token;         // if empty, accepts all tokens.
   SampleFunc sample_func;          // if empty, uses SampleTopK.
   LayersOutputFunc layers_output;  // if not empty, called after each layer.
+  ActivationsObserverFunc activations_observer;  // if set, called per-layer
   int eos_id = EOS_ID;
 };
@@ -141,15 +156,6 @@ struct TimingInfo {
   size_t tokens_generated = 0;
 };
-using PromptTokens = hwy::Span<const int>;
-// Batches of independent queries have their own prompt, previous token,
-// position in the sequence, and KVCache.
-using QueriesPromptTokens = hwy::Span<const PromptTokens>;
-using QueriesToken = hwy::Span<const int>;
-using QueriesPos = hwy::Span<const size_t>;
-using KVCaches = hwy::Span<KVCache>;
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,