Major refactor: clarify query_idx (global) vs qi. Refs #607

Fix missing pos increment for last prefill and check that in gemma_test.
Thanks to @ufownl for pointing this out.

Change argument lists to QBatch with accessors.
Increase default seq_len to 8k.

PiperOrigin-RevId: 771937385
Jan Wassenberg, 2025-06-16 02:41:30 -07:00, committed by Copybara-Service
parent 2c72ff2aa5
commit e5c81f64a1
15 changed files with 399 additions and 416 deletions

View File

@@ -109,16 +109,19 @@ void GemmaEnv::QueryModel(
 }
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
-    const QueriesPromptTokens& queries_prompt) {
+    const QueriesPromptTokens& queries_prompt,
+    const hwy::Span<const size_t>& prefix_end) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
   std::vector<QueryResult> res(num_queries);
-  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
+  const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
+                                                       const size_t pos,
+                                                       const int token, float) {
+    HWY_ASSERT(query_index < num_queries);
     std::string token_text;
     HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     res[query_index].response.append(token_text);
+    HWY_ASSERT(pos == res[query_index].tokens_generated);
     res[query_index].tokens_generated += 1;
     if (res[query_index].tokens_generated ==
        queries_prompt[query_index].size()) {
@@ -126,6 +129,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     }
     return true;
   };
+  runtime_config_.batch_stream_token = batch_stream_token;
   if (runtime_config_.verbosity >= 2) {
     fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
             runtime_config_.max_generated_tokens, runtime_config_.temperature,
@@ -137,13 +141,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   while (kv_caches_.size() < num_queries) {
     kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
   }
+  const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
+  gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
-  runtime_config_.batch_stream_token = batch_stream_token;
-  std::vector<size_t> queries_pos(num_queries, 0);
-  gemma_.GenerateBatch(runtime_config_, queries_prompt,
-                       QueriesPos(queries_pos.data(), num_queries),
-                       KVCaches(&kv_caches_[0], num_queries), timing_info);
+  gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
   return res;
 }
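The new `prefix_end` parameter defaults to an empty span, which downstream code (see `Gemma::GenerateBatch` and the `AllQueries` constructor below) interprets as all-zero, i.e. causal attention. A minimal self-contained sketch of that empty-span-as-default convention, using a hypothetical stand-in `Span` so it compiles without Highway:

#include <cassert>
#include <cstddef>

// Stand-in for hwy::Span: pointer + size, default-constructed as empty.
template <typename T>
struct Span {
  T* data = nullptr;
  size_t size = 0;
};

// Empty `prefix_end` means "0 for every query", i.e. causal attention.
size_t PrefixEndOrZero(const Span<const size_t>& prefix_end, size_t qi) {
  return prefix_end.size == 0 ? 0 : prefix_end.data[qi];
}

int main() {
  const size_t ends[2] = {5, 0};  // prefix-LM for query 0, causal for query 1.
  Span<const size_t> some{ends, 2};
  Span<const size_t> none;  // default: causal for all queries.
  assert(PrefixEndOrZero(some, 0) == 5);
  assert(PrefixEndOrZero(none, 0) == 0);
  return 0;
}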

View File

@@ -88,8 +88,10 @@ class GemmaEnv {
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   QueryResult QueryModel(const std::vector<int>& tokens);
+  // The default prefix_end means "causal attention".
   std::vector<QueryResult> BatchQueryModel(
-      const QueriesPromptTokens& queries_prompt);
+      const QueriesPromptTokens& queries_prompt,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
   // Adds turn structure to input, tokenizes and calls the above overload.
   QueryResult QueryModel(std::string& input);
   std::vector<QueryResult> BatchQueryModel(

View File

@@ -101,9 +101,11 @@ TEST_F(GemmaTest, Multiturn) {
   const ModelConfig& config = model->GetModelConfig();
   size_t abs_pos = 0;
   std::string response;
-  auto stream_token = [&](int token, float) {
-    if (config.IsEOS(token)) return true;
+  auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
+    HWY_ASSERT(query_idx == 0);
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
+    if (config.IsEOS(token)) return true;
     std::string token_text;
     EXPECT_TRUE(
         model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -115,7 +117,7 @@ TEST_F(GemmaTest, Multiturn) {
       .temperature = 0.0f,
       .gen = &s_env->MutableGen(),
       .verbosity = 2,
-      .stream_token = stream_token,
+      .batch_stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.

View File

@@ -42,13 +42,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 }
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size,
+  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : weights_config(config),
         layer_config(config.layer_configs[0]),
-        div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
         is_griffin(config.model == Model::GRIFFIN_2B),
-        query_scale(ChooseQueryScale(config)),
         x("x", Extents2D(batch_size, config.model_dim), pad_),
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@@ -63,10 +62,7 @@ struct Activations {
         pre_att_rms_out("pre_att_rms_out",
                         Extents2D(batch_size, config.model_dim), pad_),
-        att("att",
-            Extents2D(batch_size,
-                      layer_config.heads * div_seq_len.GetDivisor()),
-            pad_),
+        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
         att_out(
             "att_out",
             Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
@@ -99,7 +95,7 @@ struct Activations {
             layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
             1000000.0)),
-        gen_tokens(batch_size) {
+        query_scale(ChooseQueryScale(config)) {
     HWY_ASSERT(batch_size != 0);
     // For MatMul outputs, precompute their row pointers.
@@ -138,8 +134,6 @@ struct Activations {
       griffin_gate_x.OverrideRows(batch_size);
       griffin_multiplier.OverrideRows(batch_size);
     }
-    gen_tokens.resize(batch_size);
   }
   bool IsGlobalLayer(size_t layer_idx) const {
@@ -151,7 +145,6 @@ struct Activations {
   const LayerConfig& layer_config;
   hwy::Divisor div_seq_len;
   bool is_griffin;
-  float query_scale;
   const Extents2D none_ = Extents2D();
   const MatPadding pad_ = MatPadding::kOdd;
@@ -182,9 +175,7 @@ struct Activations {
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
-  // Storage for the last generated token from each query, passed to the next
-  // Transformer() call.
-  std::vector<int> gen_tokens;  // one per query in the batch
+  float query_scale;
 };
 }  // namespace gcpp
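`div_seq_len` now comes from the actual KV cache length rather than `config.max_seq_len`, and the attention code below wraps positions with `div_seq_len.Remainder(pos)`. A self-contained sketch of that wraparound using plain modulo (an assumption on my part: `hwy::Divisor` acts as a faster drop-in for a constant divisor):

#include <cassert>
#include <cstddef>

// Models activations.div_seq_len.Remainder(pos): KV rows behave like a ring
// buffer of `seq_len` slots, so positions at or beyond seq_len wrap around.
size_t CachePos(size_t pos, size_t seq_len) { return pos % seq_len; }

int main() {
  const size_t seq_len = 8192;  // new default per the commit message
  assert(CachePos(0, seq_len) == 0);
  assert(CachePos(8191, seq_len) == 8191);
  assert(CachePos(8192, seq_len) == 0);  // wraps to the first slot
  return 0;
}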

View File

@@ -153,15 +153,12 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
   return pos - HWY_MIN(att_window_size - 1, pos);
 }
-void DotSoftmaxWeightedSum(const size_t num_tokens,
-                           const QueriesPos& queries_pos,
-                           const QueriesPos& queries_prefix_end,
-                           const size_t layer_idx,
+void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
-                           Activations& activations, const KVCaches& kv_caches,
+                           Activations& activations, QBatch& qbatch,
                            NestedPools& pools) {
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
-  const hwy::Divisor div_queries(queries_pos.size());
+  const hwy::Divisor div_qbatch(qbatch.Size());
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
@@ -176,7 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
   // For each head/token/query, compute Q.K, softmax, and weighted V.
   // Statically partition token/query across packages.
-  const size_t num_tq = num_tokens * div_queries.GetDivisor();
+  const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
   const IndexRangePartition tq_ranges =
       StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
   ParallelizeOneRange(
@@ -185,17 +182,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
         pools.AllClusters(pkg_idx).Run(
             tq_range.begin(), tq_range.end(),
             [&](const size_t tq_idx, const size_t cluster_idx) {
-              const size_t query_idx = div_queries.Remainder(tq_idx);
-              const size_t batch_idx = div_queries.Divide(tq_idx);
-              auto& kv_cache = kv_caches[query_idx].kv_cache;
+              const size_t qi = div_qbatch.Remainder(tq_idx);
+              const size_t batch_idx = div_qbatch.Divide(tq_idx);
+              auto& kv_cache = qbatch.KV(qi).kv_cache;
               // Find the token position in the query and calculate
               // the range of cache positions to attend to.
-              const size_t pos = queries_pos[query_idx] + batch_idx;
+              const size_t pos = qbatch.Pos(qi) + batch_idx;
               const size_t start_pos =
                   StartPos(pos, activations.weights_config, layer_idx);
               size_t last_pos = pos;
-              const size_t prefix_end = queries_prefix_end[query_idx];
+              const size_t prefix_end = qbatch.PrefixEnd(qi);
               if (prefix_end > 0 && prefix_end - 1 > last_pos) {
                 // last_pos in QDotK and WeightedSumV is inclusive.
                 last_pos = prefix_end - 1;
@@ -235,14 +232,21 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
       });
 }
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
+// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
 // Fills activations.q and writes to KV cache.
-static HWY_INLINE void ComputeQKV(
-    size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, Activations& activations,
-    const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
+static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
+                                  const LayerWeightsPtrs& layer,
+                                  Activations& activations,
+                                  const QBatch& qbatch, const int flags,
+                                  MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
-  const hwy::Divisor div_queries(queries_pos.size());
-  const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
+  const hwy::Divisor div_qbatch(qbatch.Size());
+  const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
@@ -260,13 +264,12 @@ static HWY_INLINE void ComputeQKV(
                        layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = div_queries.Remainder(interleaved_idx);
-    const size_t batch_idx = div_queries.Divide(interleaved_idx);
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
     const size_t cache_pos =
-        activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
+        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
-        kv_caches[query_idx].kv_cache.Row(cache_pos) +
-        layer_idx * cache_layer_size);
+        qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
   kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
   CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
@@ -280,11 +283,11 @@ static HWY_INLINE void ComputeQKV(
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
-        const size_t query_idx = div_queries.Remainder(interleaved_idx);
-        const size_t batch_idx = div_queries.Divide(interleaved_idx);
-        const size_t pos = queries_pos[query_idx] + batch_idx;
+        const size_t qi = div_qbatch.Remainder(interleaved_idx);
+        const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+        const size_t pos = qbatch.Pos(qi) + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
-        auto& kv_cache = kv_caches[query_idx].kv_cache;
+        auto& kv_cache = qbatch.KV(qi).kv_cache;
         float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                  layer_idx * cache_layer_size +
                                  head * qkv_dim * 2;
@@ -320,35 +323,18 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
              activations.att_sums);
 }
-// `queries_prefix_end` can be null (interpreted as all-zero) for standard
-// causal attention, and must be non-null for prefix-LM style attention.
-void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
-                    const QueriesPos* queries_prefix_end,
-                    const size_t layer_idx, const LayerWeightsPtrs& layer,
-                    Activations& activations, const KVCaches& kv_caches,
-                    MatMulEnv& env, int flags) {
-  const size_t num_queries = queries_pos.size();
-  HWY_DASSERT(num_queries <= kv_caches.size());
+void GemmaAttention(size_t num_tokens, const size_t layer_idx,
+                    const LayerWeightsPtrs& layer, Activations& activations,
+                    QBatch& qbatch, MatMulEnv& env, int flags) {
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
   HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
                 "query heads must be a multiple of key-value heads");
   (void)layer_config;  // only used in HWY_DASSERT
-  std::vector<size_t> queries_prefix_end_vec;
-  QueriesPos queries_prefix_end_span;
-  if (queries_prefix_end == nullptr) {
-    queries_prefix_end_vec.assign(num_queries, 0);
-    queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
-                                         queries_prefix_end_vec.size());
-    queries_prefix_end = &queries_prefix_end_span;
-  }
-  ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
-             flags, env);
-  DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
-                        layer, activations, kv_caches, env.ctx.pools);
+  ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
+  DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                        env.ctx.pools);
   SumHeads(layer, activations, env);
 }
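Both `ComputeQKV` and `DotSoftmaxWeightedSum` flatten (token, query) pairs into one interleaved index and recover `qi`/`batch_idx` via `div_qbatch`. The same arithmetic with plain `/` and `%`, as a self-contained sketch (the real code uses `hwy::Divisor` for speed):

#include <cassert>
#include <cstddef>

int main() {
  const size_t num_tokens = 4;   // tokens per query in this tbatch
  const size_t qbatch_size = 3;  // queries in this qbatch
  const size_t num_interleaved = num_tokens * qbatch_size;
  for (size_t idx = 0; idx < num_interleaved; ++idx) {
    const size_t qi = idx % qbatch_size;         // div_qbatch.Remainder(idx)
    const size_t batch_idx = idx / qbatch_size;  // div_qbatch.Divide(idx)
    // Queries are interleaved: idx 0..2 is token 0 of queries 0..2, etc.
    assert(qi < qbatch_size && batch_idx < num_tokens);
    // The token's absolute position would be qbatch.Pos(qi) + batch_idx.
  }
  return 0;
}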

View File

@@ -35,18 +35,14 @@ namespace gcpp {
       const Activations& activations, float* HWY_RESTRICT att,              \
       float* HWY_RESTRICT att_out);                                         \
                                                                             \
-  void DotSoftmaxWeightedSum(const size_t num_tokens,                       \
-                             const QueriesPos& queries_pos,                 \
-                             const QueriesPos& queries_prefix_end,          \
-                             size_t layer_idx, const LayerWeightsPtrs& layer,\
-                             Activations& activations,                      \
-                             const KVCaches& kv_caches, NestedPools& pools);\
+  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,     \
+                             const LayerWeightsPtrs& layer,                 \
+                             Activations& activations, QBatch& qbatch,      \
+                             NestedPools& pools);                           \
                                                                             \
-  void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,     \
-                      const QueriesPos* queries_prefix_end,                 \
-                      const size_t layer_idx, const LayerWeightsPtrs& layer,\
-                      Activations& activations, const KVCaches& kv_caches,  \
-                      MatMulEnv& env, int flags);                           \
+  void GemmaAttention(size_t num_tokens, const size_t layer_idx,            \
+                      const LayerWeightsPtrs& layer,                        \
+                      Activations& activations, QBatch& qbatch,             \
+                      MatMulEnv& env, int flags);                           \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \
   }  // namespace NAMESPACE

View File

@@ -205,7 +205,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
       // RuntimeConfig runtime_config = { ... }; // This was already defined
       double image_tokens_start = hwy::platform::Now();
       // Pass the populated image object to GenerateImageTokens
-      model.GenerateImageTokens(runtime_config, image, image_tokens);
+      model.GenerateImageTokens(runtime_config,
+                                active_conversation->kv_cache->SeqLen(), image,
+                                image_tokens);
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       ss.str("");

View File

@@ -45,7 +45,6 @@
 #include "gemma/configs.h"
 #include "gemma/model_store.h"
-#include "gemma/tokenizer.h"
 #include "gemma/weights.h"
 #include "io/blob_store.h"
 #include "io/io.h"  // Path
@@ -62,14 +61,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
-void Attention(LayerAttentionType type, size_t num_tokens,
-               const QueriesPos& queries_pos,
-               const QueriesPos& queries_prefix_end, const size_t layer_idx,
-               const LayerWeightsPtrs& layer, Activations& activations,
-               const KVCaches& kv_caches, MatMulEnv& env) {
+void Attention(LayerAttentionType type, const size_t num_tokens,
+               const size_t layer_idx, const LayerWeightsPtrs& layer,
+               Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
-    GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
-                   layer, activations, kv_caches, env,
+    GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
                    /*flags=*/0);
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@@ -77,23 +73,23 @@ void Attention(LayerAttentionType type, size_t num_tokens,
     // so map `layer` to the Griffin layer index.
     const size_t griffin_layer =
         activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
-    GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
-                     &layer, kv_caches, env);
+    GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
+                     env);
   }
 }
-static HWY_NOINLINE void TransformerLayer(
-    const size_t num_tokens, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env) {
+static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
+                                          const size_t layer_idx,
+                                          const LayerWeightsPtrs& layer,
+                                          Activations& activations,
+                                          QBatch& qbatch, MatMulEnv& env) {
   const LayerConfig& layer_config = layer.layer_config;
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
                  activations.pre_att_rms_out);
-  Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
-            layer_idx, layer, activations, kv_caches, env);
+  Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
+            qbatch, env);
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
            activations.att_sums);
@@ -134,7 +130,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const ModelWeightsPtrs& weights,
              MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
@@ -142,14 +138,14 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }
   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }
@@ -169,34 +165,27 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
   });
   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
   }
   return image_token_position;
 }
-// Incremented in-place by Prefill* and DecodeStepT.
-using QueriesMutablePos = hwy::Span<size_t>;
 // Populates KV cache for batches of tokens from one query at a time. This is
 // called if prompts are longer than the query batch size, and also in
 // prefix-LM mode (end > 0), which must see all tokens in one batch.
-static HWY_NOINLINE void PrefillTBatch(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
+static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
+                                       const RuntimeConfig& runtime_config,
+                                       const ModelWeightsPtrs& weights,
+                                       Activations& activations,
+                                       QBatch& qbatch, MatMulEnv& env,
+                                       hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.PrefillT");
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
-  HWY_DASSERT(num_queries == kv_caches.size());
   // Batches are important for amortizing loading weights over multiple tokens.
   // This is possible in prefill because we know all tokens beforehand, whereas
@@ -210,19 +199,16 @@ static HWY_NOINLINE void PrefillTBatch(
   const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
   // For each query. `qi` is within the batch, not the global query index.
-  for (size_t qi = 0; qi < num_queries; ++qi) {
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
-    // Single query at a time, so pass slices of the spans because
-    // GemmaAttention will only access the first KV cache and position.
-    QueriesPos single_query_pos(&queries_pos[qi], 1);
-    QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
-    KVCaches single_kv_cache(&kv_caches[qi], 1);
-    const size_t prompt_size = queries_prompt[qi].size();
+    // One query at a time, batching will be the query's prompt tokens.
+    QBatch qbatch_1 = qbatch.Single(qi);
+    const size_t prompt_size = qbatch_1.Prompt(0).size();
     // In autoregressive mode, we don't need to prefill the last token, so - 1.
     size_t prefill_this_query = prompt_size - 1;
-    const size_t prefix_end_this_query = queries_prefix_end[qi];
+    const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
     // We can't attend beyond the prompt_size.
     HWY_ASSERT(prefix_end_this_query <= prompt_size);
     // Special case: if the prefix includes the last token, we need to prefill
@@ -251,9 +237,9 @@ static HWY_NOINLINE void PrefillTBatch(
       // Fill activations.x (much faster than TransformerLayer).
       size_t image_token_position = 0;
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
+        const size_t pos = qbatch_1.Pos(0) + ti;
         const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
+        const int token = qbatch_1.Prompt(0)[pos_in_prompt];
         image_token_position = EmbedMMToken(
             token, ti, pos, pos_in_prompt, config, weights, activations.x,
             runtime_config.image_tokens, image_token_position);
@@ -262,18 +248,17 @@ static HWY_NOINLINE void PrefillTBatch(
       // Transformer with one batch of tokens from a single query.
       for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
            ++layer_idx) {
-        TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
-                         layer_idx, *weights.GetLayer(layer_idx), activations,
-                         single_kv_cache, env);
+        TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                         activations, qbatch_1, env);
       }
       // NOTE: we unconditionally call StreamToken, even if EOS.
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
+        const size_t pos = qbatch_1.Pos(0) + ti;
         const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
+        const int token = qbatch_1.Prompt(0)[pos_in_prompt];
         if (pos_in_prompt < prompt_size - 1) {
-          runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
+          runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
         } else {
           // The last token will be streamed later and we should only get here
          // if we need to attend to the last token because it is in the prefix.
@@ -281,7 +266,7 @@ static HWY_NOINLINE void PrefillTBatch(
         }
       }
-      queries_pos[qi] += tbatch_size;
+      qbatch_1.MutablePos(0) += tbatch_size;
     }  // for tbatch_start
     if (attend_to_last_token) {
       // We need to rewind the position for the last token that we only
@@ -290,148 +275,125 @@ static HWY_NOINLINE void PrefillTBatch(
       // decoding. Alternatives: (1) real masking; (2) always prefill the last
       // token and only generate the next one from the already prefilled
       // activations.
-      queries_pos[qi] -= 1;
+      qbatch_1.MutablePos(0) -= 1;
    }
   }
 }
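The special case mentioned above: in causal mode, prefill stops at `prompt_size - 1`, but if the prefix covers the whole prompt, the last token must also be prefilled and the position rewound by one so decoding can re-derive the next token. A self-contained sketch of that arithmetic, under the assumption (suggested by the comments in this hunk) that the rewind happens after all token batches:

#include <cassert>
#include <cstddef>

struct PrefillPlan {
  size_t prefill_this_query;
  bool attend_to_last_token;
};

// Mirrors the logic described in PrefillTBatch: causal prefill stops one
// short of the prompt, unless the prefix includes the last prompt token.
PrefillPlan Plan(size_t prompt_size, size_t prefix_end) {
  assert(prefix_end <= prompt_size);  // "can't attend beyond the prompt_size"
  PrefillPlan plan{prompt_size - 1, false};
  if (prefix_end == prompt_size) {
    plan.prefill_this_query = prompt_size;  // must see the last token too
    plan.attend_to_last_token = true;       // pos is rewound by 1 afterwards
  }
  return plan;
}

int main() {
  assert(Plan(10, 0).prefill_this_query == 9);  // causal: skip last token
  assert(Plan(10, 10).attend_to_last_token);    // prefix-LM, full prefix
  size_t pos = 0;
  const PrefillPlan plan = Plan(10, 10);
  pos += plan.prefill_this_query;           // after all tbatch increments
  if (plan.attend_to_last_token) pos -= 1;  // rewind so decode re-emits it
  assert(pos == 9);
  return 0;
}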
-// Embeds token and calls each TransformerLayer. `queries_token` is the previous
-// token from each query, and `queries_pos` are their position in the sequence.
+// Embeds PrevToken (one from each query) and calls each TransformerLayer.
 // Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
 // token-batched `PrefillTBatch`.
-static HWY_NOINLINE void Transformer(
-    const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
-  const size_t num_queries = queries_token.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
+static HWY_NOINLINE void Transformer(const ModelConfig& config,
+                                     const RuntimeConfig& runtime_config,
+                                     const ModelWeightsPtrs& weights,
+                                     Activations& activations, QBatch& qbatch,
+                                     MatMulEnv& env) {
   if (HWY_UNLIKELY(runtime_config.layers_output)) {
-    for (size_t qi = 0; qi < num_queries; ++qi) {
-      const float token_f = queries_token[qi];
-      runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
-                                   1);
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      const float token_f = qbatch.PrevToken(qi);
+      runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                   "tokens", -1, &token_f, 1);
     }
   }
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
                  /*pos_in_prompt=*/0, config, weights, activations.x);
   }
   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
-    TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
-                     layer_idx, *weights.GetLayer(layer_idx), activations,
-                     kv_caches, env);
+    TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
+                     activations, qbatch, env);
     if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(queries_pos, layer_idx, activations);
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+          activations);
    }
  }
 }
 // Populates KV cache for the batch queries, one token at a time. Only called
 // for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
-static HWY_NOINLINE void PrefillQBatch(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t max_prompt_size, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
-    hwy::BitSet4096<>& non_eos) {
+static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
+                                       const ModelConfig& config,
+                                       const RuntimeConfig& runtime_config,
+                                       const ModelWeightsPtrs& weights,
+                                       Activations& activations,
+                                       QBatch& qbatch, MatMulEnv& env,
+                                       hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.Prefill");
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
-  HWY_DASSERT(num_queries == activations.x.Rows());
-  HWY_DASSERT(num_queries == kv_caches.size());
-  hwy::BitSet4096<> prefill_active;
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    prefill_active.Set(qi);
-    HWY_DASSERT(queries_prefix_end[qi] == 0);
-    (void)queries_prefix_end;
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    non_eos.Set(qi);
+    HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }
-  non_eos = prefill_active;
   // In autoregressive mode, we don't prefill the last token, hence - 1.
   for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
        ++pos_in_prompt) {
-    // Streams that have already finished prefill no longer interleave/stream.
-    for (size_t qi = 0; qi < num_queries; ++qi) {
-      if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
-        prefill_active.Clear(qi);
-        activations.gen_tokens[qi] = config.eos_id;
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      int token = config.eos_id;
+      if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
+        token = qbatch.Prompt(qi)[pos_in_prompt];
+        // Ignore StreamToken return value because requesting to stop does not
+        // make sense during prefill.
+        (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                         token, 0.0f);
       }
+      qbatch.PrevToken(qi) = token;
     }
-    // Batch := interleaved tokens, one from each non-EOS query.
-    prefill_active.Foreach([&](size_t qi) {
-      activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
-    });
-    // One token from each query in the batch. Increments queries_pos.
+    // The input (PrevToken) is one token from each query in the batch.
     // Do not call DecodeStepT because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
-    Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-                queries_pos, queries_prefix_end, config, runtime_config,
-                weights, activations, kv_caches, env);
-    prefill_active.Foreach([&](size_t qi) {
-      const int token = queries_prompt[qi][pos_in_prompt];
-      // Ignore any user request to stop during prefill.
-      (void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
-                                       token, 0.0f);
-      queries_pos[qi] += 1;
-    });
-  }  // pos_in_prompt
+    Transformer(config, runtime_config, weights, activations, qbatch, env);
+  }
 }
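During query-batched prefill, prompts of different lengths advance in lockstep; the loop above suggests that a query whose prompt is exhausted feeds `config.eos_id` as a placeholder and is not streamed. A self-contained sketch of that padding rule (the EOS value and prompt contents are made up for illustration):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const int kEos = 1;  // stand-in for config.eos_id
  const std::vector<std::vector<int>> prompts = {{7, 8, 9, 10}, {7, 8}};
  const size_t max_prompt_size = 4;
  // As in PrefillQBatch: iterate positions 0..max_prompt_size-2 and pick the
  // prompt token while available, else pad with EOS (only real tokens stream).
  for (size_t pos = 0; pos < max_prompt_size - 1; ++pos) {
    for (size_t qi = 0; qi < prompts.size(); ++qi) {
      int token = kEos;
      const bool real = pos < prompts[qi].size() - 1;
      if (real) token = prompts[qi][pos];  // would also call StreamToken here
      if (qi == 1 && pos >= 1) assert(token == kEos);  // query 1 padded early
      (void)token;
    }
  }
  return 0;
}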
-// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
-// and updates `non_eos` if the query is at the end of its sequence.
-static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
-                               const float prob, const ModelConfig& config,
+// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
+// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// query is at the end of its sequence.
+static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
+                               const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               Activations& activations,
-                               hwy::BitSet4096<>& non_eos) {
-  HWY_DASSERT(non_eos.Get(qi));
-  // User decided to stop: set next token to primary EOS.
-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
+                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+  HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.
+  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
+                                               qbatch.Pos(qi), token, prob))) {
+    // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
     HWY_DASSERT(config.IsEOS(token));
   }
-  // Primary or secondary EOS: mark query as EOS.
-  if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
-  activations.gen_tokens[qi] = token;
+  qbatch.PrevToken(qi) = token;
+  qbatch.MutablePos(qi) += 1;
+  // Primary or secondary EOS: mark query as EOS, but still increment (for
+  // multi-turn, we should still keep the prior EOS).
+  if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }
 // For a batch of queries, runs Transformer, computes logits, samples and
 // streams the token.
-static void DecodeStepT(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_mutable_pos,
-    const QueriesPos& queries_prefix_end, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    const SampleFunc& sample_token, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-    TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == activations.x.Rows());
-  Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-              queries_mutable_pos, queries_prefix_end, config, runtime_config,
-              weights, activations, kv_caches, env);
+static void DecodeStepT(const ModelConfig& config,
+                        const RuntimeConfig& runtime_config,
+                        const ModelWeightsPtrs& weights,
+                        const SampleFunc& sample_token,
+                        Activations& activations, QBatch& qbatch,
+                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
+                        TimingInfo& timing_info) {
+  HWY_DASSERT(qbatch.Size() == activations.x.Rows());
+  Transformer(config, runtime_config, weights, activations, qbatch, env);
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-    runtime_config.activations_observer(queries_mutable_pos, -1, activations);
+    runtime_config.activations_observer(
+        QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
   }
   {
@@ -447,10 +409,8 @@ static void DecodeStepT(
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
-    StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
-                       tp.prob, config, runtime_config, activations, non_eos);
-    if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
+    StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
+                       non_eos);
   });
 }
@@ -477,46 +437,24 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
-// Generates one continuation for each query in `queries_prompt`, which is one
-// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
-//
-// `queries_pos` stores the KV cache position for each query. In the first turn
-// of a chat, pos = 0; we increment each query's position after each token.
-//
-// `query_idx_start` is the query_idx of the first query in the batch, so that
-// `StreamFunc` gets the global query index, not relative to the batch.
-static void GenerateT(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(num_queries <= 4096);  // non_eos uses `BitSet4096`.
-  HWY_ASSERT(num_queries == queries_pos_in.size());
-  HWY_ASSERT(num_queries == queries_prefix_end.size());
-  HWY_ASSERT(num_queries <= activations.x.Rows());
-  HWY_ASSERT(num_queries == kv_caches.size());
+// Decode: generates one continuation token for each query in `qbatch`.
+static void GenerateT(const ModelConfig& config,
+                      const RuntimeConfig& runtime_config,
+                      const ModelWeightsPtrs& weights,
+                      Activations& activations, QBatch& qbatch, MatMulEnv& env,
+                      TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
-  for (size_t i = 0; i < kv_caches.size(); ++i) {
-    if (queries_pos_in[i] == 0) {
-      kv_caches[i].ZeroGriffinCache();  // No-op for non-Griffin models.
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    if (qbatch.MutablePos(qi) == 0) {
+      qbatch.KV(qi).ZeroGriffinCache();  // No-op for non-Griffin models.
     }
   }
-  // Copy so we can increment without requiring users to pass in a mutable span.
-  std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
-                                       queries_pos_in.cend());
-  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
-                                              queries_pos_copy.size());
   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
-  size_t prefill_tokens = 0;
-  const size_t seq_len = kv_caches[0].SeqLen();
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    const PromptTokens& prompt = queries_prompt[qi];
+  size_t prefill_tokens = 0;  // only for timing.
+  const size_t seq_len = qbatch.KV(0).SeqLen();
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    const PromptTokens& prompt = qbatch.Prompt(qi);
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
     // Prefill stops before size - 1 because the last prompt token is the
@@ -526,43 +464,38 @@ static void GenerateT(
     // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
-    all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
+    all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
     // We use a single divisor, so all sequence lengths must be the same.
-    HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
+    HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   HWY_ASSERT(prefill_tokens < seq_len);
   activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
   // qi loops anyway.
-  hwy::BitSet4096<> non_eos;
+  hwy::BitSet4096<> non_eos;  // indexed by qi
   timing_info.prefill_start = hwy::platform::Now();
   // Batch over the larger of prompt length, or queries.
-  if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
-    activations.SetBatchSize(num_queries);  // required before PrefillQBatch
-    PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, max_prompt_size, config, runtime_config,
-                  weights, activations, kv_caches, env, non_eos);
+  if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
+    activations.SetBatchSize(qbatch.Size());  // required before PrefillQBatch
+    PrefillQBatch(max_prompt_size, config, runtime_config, weights,
+                  activations, qbatch, env, non_eos);
   } else {
-    PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, config, runtime_config, weights,
-                  activations, kv_caches, env, non_eos);
-    activations.SetBatchSize(num_queries);  // Restore after PrefillTBatch.
+    PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
+                  non_eos);
+    activations.SetBatchSize(qbatch.Size());  // Restore after PrefillTBatch.
   }
-  HWY_DASSERT(num_queries == non_eos.Count());
+  HWY_DASSERT(non_eos.Count() == qbatch.Size());
   timing_info.NotifyPrefill(prefill_tokens);
   // queries_pos have been incremented by Prefill.
   // Stream the last prompt token from each query, fill activations.gen_tokens.
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    const size_t last_token_pos_in_prompt =
-        queries_mutable_pos[qi] - queries_pos_in[qi];
-    StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
-                       queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
-                       config, runtime_config, activations, non_eos);
-    // No incrementing queries_mutable_pos[qi].
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
+                       runtime_config, qbatch, non_eos);
   }
   size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -577,10 +510,8 @@ static void GenerateT(
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, config, runtime_config, weights,
-                  sample_token, activations, kv_caches, env, non_eos,
-                  timing_info);
+      DecodeStepT(config, runtime_config, weights, sample_token, activations,
+                  qbatch, env, non_eos, timing_info);
     }
     timing_info.NotifyGenerateDone();
   }
@@ -591,61 +522,38 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const RuntimeConfig& runtime_config,
                      const ModelWeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
-  constexpr size_t kNumQueries = 1;
-  const size_t qbatch_start = 0;
-  const size_t max_batch_size =
-      HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
-  // TODO: move into Gemma?
-  Activations activations(config, max_batch_size, env.row_ptrs);
-  const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
-  QueriesPos queries_pos(&pos, kNumQueries);
-  const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
-  const KVCaches kv_caches{&kv_cache, kNumQueries};
-  GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
-            config, runtime_config, weights, activations, kv_caches, env,
+  Activations activations(config, runtime_config.prefill_tbatch_size,
+                          kv_cache.SeqLen(), env.row_ptrs);
+  AllQueries all_queries(prompt, pos, prefix_end,
+                         hwy::Span<KVCache>(&kv_cache, 1));
+  QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
+  GenerateT(config, runtime_config, weights, activations, qbatch, env,
             timing_info);
 }
 // Splits the input into batches of at most `runtime_config.decode_qbatch_size`
 // queries, and calls `GenerateT` on each batch.
-void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
-                    const QueriesPos& queries_pos,
-                    const QueriesPos& queries_prefix_end,
-                    const ModelConfig& config,
+void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
+                    const ModelWeightsPtrs& weights, AllQueries& all_queries,
                     MatMulEnv& env, TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(queries_pos.size() == num_queries);
-  HWY_ASSERT(kv_caches.size() >= num_queries);
-  const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
-  const size_t max_batch_size =
-      HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size, env.row_ptrs);
-  for (size_t qbatch_start = 0; qbatch_start < num_queries;
-       qbatch_start += max_qbatch_size) {
-    // Generate one batch of tokens from `qbatch_size` queries.
-    const size_t qbatch_size =
-        HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
-    const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
-                                             qbatch_size);
-    QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
-    const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                       qbatch_size);
-    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
-              config, runtime_config, weights, activations, qbatch_kv, env,
+  const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
+                                        runtime_config.prefill_tbatch_size);
+  Activations activations(config, max_batch_size,
+                          all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
+  for (size_t start = 0; start < all_queries.NumQueries();
+       start += runtime_config.decode_qbatch_size) {
+    QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
+    // Generate a batch of one token for each of `qbatch.Size()` queries.
+    GenerateT(config, runtime_config, weights, activations, qbatch, env,
              timing_info);
   }
 }
 void GenerateImageTokensT(const ModelConfig& config,
-                          const RuntimeConfig& runtime_config,
+                          const RuntimeConfig& runtime_config, size_t seq_len,
                           const ModelWeightsPtrs& weights, const Image& image,
                           ImageTokens& image_tokens, MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
@@ -656,7 +564,8 @@ void GenerateImageTokensT(const ModelConfig& config,
   const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
       num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
+  Activations prefill_activations(vit_config, num_tokens, num_tokens,
+                                  env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);
@@ -714,36 +623,25 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
 }
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          const QueriesPromptTokens& queries_prompt,
-                          const QueriesPos& queries_pos,
-                          const QueriesPos& queries_prefix_end,
-                          const KVCaches& kv_caches,
+                          AllQueries& all_queries,
                           TimingInfo& timing_info) const {
-  // If we did not get passed prefix ends (size 0), assume 0 and pass that on.
-  QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
-  std::vector<size_t> prefix_end_vec;
-  if (queries_prefix_end.size() == 0) {  // hwy::Span lacks empty()
-    prefix_end_vec.resize(queries_prompt.size(), 0);
-    queries_prefix_end_or_zeros =
-        QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
-  }
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
-  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
-      queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
-      runtime_config, weights_, kv_caches, env_, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
+                                       weights_, all_queries, env_,
+                                       timing_info);
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
-                                const Image& image,
+                                size_t seq_len, const Image& image,
                                 ImageTokens& image_tokens) const {
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
-  HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
-      model_.Config(), runtime_config, weights_, image, image_tokens, env_);
+  HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
+                                             seq_len, weights_, image,
+                                             image_tokens, env_);
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }

View File

@ -38,6 +38,129 @@
namespace gcpp { namespace gcpp {
struct PerQuery {
PromptTokens prompt;
// Position in the KV cache: initially zero for the first turn, or when
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
size_t mutable_pos;
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
// which might differ from `prompt.size() - 1` for prefix-LM.
size_t initial_pos;
// Zero for causal attention, or the end of the prefix for prefix-LM style
// attention in Paligemma.
size_t prefix_end;
KVCache& kv_cache;
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
int prev_token = 0;
};
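To make the two position fields concrete, a hypothetical helper (ours, not in the commit) for the quantity the `mutable_pos`/`initial_pos` comment describes:

// Sketch: the last prefill token is prompt[mutable_pos - initial_pos]. After
// causal prefill, mutable_pos = initial_pos + prompt.size() - 1, so this is
// the final prompt token; prefix-LM prefill may advance mutable_pos
// differently, hence the explicit bookkeeping.
int LastPrefillToken(const gcpp::PerQuery& query) {
  return query.prompt[query.mutable_pos - query.initial_pos];
}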
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompt,
.mutable_pos = pos,
.initial_pos = pos,
.prefix_end = prefix_end,
.kv_cache = kv_caches[i],
});
}
}
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i],
});
}
}
size_t NumQueries() const { return per_query_.size(); }
PerQuery& operator[](size_t query_idx) {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
const PerQuery& operator[](size_t query_idx) const {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
private:
std::vector<PerQuery> per_query_;
};
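The first constructor is the single-prompt path; a sketch, assuming an existing prompt and caches (the helper name is ours):

// Sketch: replicate one prompt across all caches, continuing at `pos`, as
// GenerateSingleT does. prefix_end = 0 requests causal attention.
gcpp::AllQueries ReplicatedQueries(const gcpp::PromptTokens& prompt, size_t pos,
                                   std::vector<gcpp::KVCache>& caches) {
  const hwy::Span<gcpp::KVCache> kv(caches.data(), caches.size());
  return gcpp::AllQueries(prompt, pos, /*prefix_end=*/0, kv);
}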
// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
public:
QBatch(size_t start, size_t max_size, AllQueries& queries)
: start_(start),
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
}
// Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; }
// Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_);
return start_ + qi;
}
// Accessor functions to bridge the previous SoA and current AoS layout.
const PromptTokens& Prompt(size_t qi) const {
return queries_[QueryIdx(qi)].prompt;
}
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
size_t InitialPos(size_t qi) const {
return queries_[QueryIdx(qi)].initial_pos;
}
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private:
size_t start_;
size_t max_size_;
AllQueries& queries_;
size_t size_;
};
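A sketch of how a consumer might walk a `QBatch` (the function is ours): `qi` is batch-relative, and `QueryIdx(qi)` recovers the global index that `BatchStreamFunc` and `AllQueries` expect.

// Sketch: stream the previous token of each query in the batch.
void StreamPrevTokens(const gcpp::RuntimeConfig& runtime_config,
                      gcpp::QBatch& qbatch) {
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    const size_t query_idx = qbatch.QueryIdx(qi);  // global, not batch-relative
    runtime_config.StreamToken(query_idx, qbatch.Pos(qi), qbatch.PrevToken(qi),
                               /*prob=*/0.0f);
  }
}

`Single(qi)` returns a one-query view over the same `AllQueries`, which lets the prefill and single-query generation paths share this code shape.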
struct TimingInfo {
  // be sure to populate prefill_start before calling NotifyPrefill.
  void NotifyPrefill(size_t tokens) {
@@ -100,8 +223,6 @@ struct TimingInfo {
// Returns the `MatMulEnv` after calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);

-using KVCaches = hwy::Span<KVCache>;
-
class Gemma {
 public:
  // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
@@ -133,24 +254,11 @@ class Gemma {
                size_t pos, size_t prefix_end, KVCache& kv_cache,
                TimingInfo& timing_info) const;

-  // `queries_pos` are the positions in the KV cache. Users are responsible for
-  // incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
  void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const QueriesPromptTokens& queries_prompt,
-                     const QueriesPos& queries_pos, const KVCaches& kv_caches,
-                     TimingInfo& timing_info) const {
-    GenerateBatch(runtime_config, queries_prompt, queries_pos,
-                  /*queries_prefix_end=*/{}, kv_caches, timing_info);
-  }
-
-  // For prefix-LM style attention, we can pass the ends of the prefixes.
-  void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const QueriesPromptTokens& queries_prompt,
-                     const QueriesPos& queries_pos,
-                     const QueriesPos& queries_prefix_end,
-                     const KVCaches& kv_caches, TimingInfo& timing_info) const;
+                     AllQueries& all_queries, TimingInfo& timing_info) const;

  // Generates the image tokens by running the image encoder ViT.
-  void GenerateImageTokens(const RuntimeConfig& runtime_config,
+  void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
                           const Image& image, ImageTokens& image_tokens) const;

 private:

View File

@@ -82,8 +82,9 @@ using ImageTokens = MatStorageT<float>;
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
-// For prompt tokens, probability is 0.0f.
-// StreamFunc should return false to stop generation and true to continue.
+// For prompt tokens, probability is 0.0f. Generation continues if this returns
+// true and stops if it returns false. Note that query_idx is absolute, not
+// relative to the batch.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
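For illustration, a hedged sketch of a conforming `BatchStreamFunc` (names ours): because `query_idx` is absolute, per-query state must be sized for all queries, not for one qbatch.

// Sketch: tally generated tokens per query. total_num_queries is assumed to
// equal AllQueries::NumQueries(), not the size of the current qbatch.
std::vector<size_t> generated(total_num_queries, 0);
gcpp::BatchStreamFunc stream = [&generated](size_t query_idx, size_t /*pos*/,
                                            int /*token*/, float /*prob*/) {
  ++generated[query_idx];
  return true;  // returning false stops generation
};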
@@ -112,8 +113,8 @@ using ActivationsObserverFunc =
// RuntimeConfig holds configuration for a single generation run.
// TODO: move into InferenceArgs, use that directly.
struct RuntimeConfig {
-  // If not empty, batch_stream_token is called for each token in the batch,
-  // instead of stream_token.
+  // If non-null, `batch_stream_token` is called for each token in the batch,
+  // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
  bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
    if (batch_stream_token) {
      return batch_stream_token(query_idx, pos, token, prob);
@@ -189,9 +190,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
            "developer/debug info).\n    Default = 1.",
            1);  // Changed verbosity level to 1 since it's user-facing
-    visitor(seq_len, "seq_len", size_t{2048},
+    visitor(seq_len, "seq_len", size_t{8192},
            "Sequence length, capped by ModelConfig.max_seq_len.");
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
            "Maximum number of tokens to generate.");
    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},

View File

@@ -39,16 +39,10 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

-// Different functions use different naming conventions for the number of
-// tokens. Functions that are query-independent, such as RMSNorm*, call the
-// count `num_interleaved`. Functions that are query-dependent, such as
-// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
-// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
-void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
-                      size_t griffin_layer, Activations& activations,
-                      const LayerWeightsPtrs* layer_weights,
-                      const KVCaches& kv_caches, MatMulEnv& env) {
+void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
+                      const LayerWeightsPtrs* layer_weights,
+                      Activations& activations, QBatch& qbatch,
+                      MatMulEnv& env) {
  PROFILER_ZONE("Gen.Griffin");
  hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
  namespace hn = hwy::HWY_NAMESPACE;
@@ -64,9 +58,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
  const size_t kHeadDim = model_dim / heads;
  const size_t kMatrixSize = kHeadDim * kHeadDim;
-  const size_t num_queries = queries_pos.size();
-  const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
-  const size_t num_interleaved = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * qbatch.Size();
+  const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));

  // X / Y linear layers.
  // TODO: MatMul
@@ -91,17 +84,17 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
  // Conv1D.
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
-    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
-    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
-    const size_t pos = queries_pos[query_idx] + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
    // cache[i] = input at time t-i.
    float* HWY_RESTRICT cache[kMaxConv1DWidth];
    cache[0] = x;
    for (size_t i = 1; i < conv_1d_width; i++) {
      cache[i] =
-          kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
+          qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
          ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
    }
    for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@@ -127,16 +120,16 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
  // RGLRU
  for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
       ++interleaved_idx) {
-    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
-    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
-    const size_t pos = queries_pos[query_idx] + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
-    float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
+    float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
+    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
+    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
    float* HWY_RESTRICT rnn_state =
-        kv_caches[query_idx].rglru_cache.Row(griffin_layer);
+        qbatch.KV(qi).rglru_cache.Row(griffin_layer);
    pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      size_t head_offset = head * kHeadDim;
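The loops above interleave queries within each token step, i.e. `interleaved_idx = batch_idx * qbatch.Size() + qi`; `hwy::Divisor` inverts that mapping with a precomputed divisor instead of a hardware divide. A standalone sketch (header path assumed):

#include <cstdint>
#include <cstdio>

#include "hwy/base.h"  // hwy::Divisor (header path assumed)

int main() {
  const uint32_t qbatch_size = 3;  // hypothetical
  const uint32_t num_tokens = 2;   // tokens per query in this batch
  const hwy::Divisor div_qbatch(qbatch_size);
  for (uint32_t i = 0; i < num_tokens * qbatch_size; ++i) {
    // Inverse of i = batch_idx * qbatch_size + qi:
    const uint32_t qi = div_qbatch.Remainder(i);
    const uint32_t batch_idx = div_qbatch.Divide(i);
    printf("interleaved %u -> qi %u, batch_idx %u\n", i, qi, batch_idx);
  }
  return 0;
}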

View File

@@ -26,13 +26,13 @@
namespace gcpp {

// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE)                              \
  namespace NAMESPACE {                                                    \
-  void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
-                        size_t griffin_layer, Activations& activations,   \
-                        const LayerWeightsPtrs* layer_weights,            \
-                        const KVCaches& kv_caches, MatMulEnv& env);       \
+  void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,          \
+                        const LayerWeightsPtrs* layer_weights,            \
+                        Activations& activations, QBatch& qbatch,         \
+                        MatMulEnv& env);                                  \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */              \
  }  // namespace NAMESPACE

// Function declarations for each SIMD target. Allows direct call from the

View File

@@ -120,7 +120,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                  .verbosity = inference.verbosity,
                                  .use_spinning = threading.spin};
  double image_tokens_start = hwy::platform::Now();
-  gemma.GenerateImageTokens(runtime_config, image, image_tokens);
+  gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
+                            image_tokens);
  if (inference.verbosity >= 1) {
    double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
    fprintf(stderr,

View File

@@ -57,7 +57,8 @@ class PaliGemmaTest : public ::testing::Test {
    image.Resize(image_size, image_size);
    RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
                                    .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
+                              image, *image_tokens_);
  }

  std::string GemmaReply(const std::string& prompt_text) const {
@@ -107,12 +108,11 @@ class PaliGemmaTest : public ::testing::Test {
TEST_F(PaliGemmaTest, QueryObjects) {
  ASSERT_NE(s_env->GetGemma(), nullptr);
  const char* question = "answer en What objects are in the image?";
-  const char* expected_substring = "Building, Tower";  // 3B PT 224, 10B Mix 224
+  // 3B PT/Mix 224, 10B Mix 224
+  const char* expected_substring = "Building, Tower";
  const Model model = s_env->GetGemma()->GetModelConfig().model;
  if (model == Model::PALIGEMMA2_3B_448) {
    expected_substring = "Lake.";
-  } else if (model == Model::PALIGEMMA2_3B_224) {
-    expected_substring = "Cloud, Water.";
  } else if (model == Model::PALIGEMMA2_10B_224) {
    expected_substring = "Building.";
  }

View File

@@ -190,7 +190,8 @@ class GemmaModel {
                                   gcpp::MatPadding::kOdd));
    gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                          .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
+                              c_image, *image_tokens_);
  }

  // Generates a response to the given prompt, using the last set image.