Major refactor: clarify query_idx (global) vs qi. Refs #607

Fix missing pos increment for last prefill and check that in gemma_test.
Thanks to @ufownl for pointing this out.

Change argument lists to QBatch with accessors.
Increase default seq_len to 8k.

PiperOrigin-RevId: 771937385
This commit is contained in:
Jan Wassenberg 2025-06-16 02:41:30 -07:00 committed by Copybara-Service
parent 2c72ff2aa5
commit e5c81f64a1
15 changed files with 399 additions and 416 deletions

View File

@ -109,16 +109,19 @@ void GemmaEnv::QueryModel(
}
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt) {
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries != 0);
std::vector<QueryResult> res(num_queries);
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
size_t query_index, size_t pos,
int token, float) {
const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
const size_t pos,
const int token, float) {
HWY_ASSERT(query_index < num_queries);
std::string token_text;
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].response.append(token_text);
HWY_ASSERT(pos == res[query_index].tokens_generated);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
queries_prompt[query_index].size()) {
@ -126,6 +129,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
}
return true;
};
runtime_config_.batch_stream_token = batch_stream_token;
if (runtime_config_.verbosity >= 2) {
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
@ -137,13 +141,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
while (kv_caches_.size() < num_queries) {
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
}
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token;
std::vector<size_t> queries_pos(num_queries, 0);
gemma_.GenerateBatch(runtime_config_, queries_prompt,
QueriesPos(queries_pos.data(), num_queries),
KVCaches(&kv_caches_[0], num_queries), timing_info);
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
return res;
}

View File

@ -88,8 +88,10 @@ class GemmaEnv {
// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
QueryResult QueryModel(const std::vector<int>& tokens);
// The default prefix_end means "causal attention".
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt);
const QueriesPromptTokens& queries_prompt,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
// Adds turn structure to input, tokenizes and calls the above overload.
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(

View File

@ -101,9 +101,11 @@ TEST_F(GemmaTest, Multiturn) {
const ModelConfig& config = model->GetModelConfig();
size_t abs_pos = 0;
std::string response;
auto stream_token = [&](int token, float) {
if (config.IsEOS(token)) return true;
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
HWY_ASSERT(query_idx == 0);
HWY_ASSERT(pos == abs_pos);
++abs_pos;
if (config.IsEOS(token)) return true;
std::string token_text;
EXPECT_TRUE(
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
@ -115,7 +117,7 @@ TEST_F(GemmaTest, Multiturn) {
.temperature = 0.0f,
.gen = &s_env->MutableGen(),
.verbosity = 2,
.stream_token = stream_token,
.batch_stream_token = stream_token,
};
TimingInfo timing_info{.verbosity = 0};
// First "say" something slightly unusual.

View File

@ -42,13 +42,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
}
struct Activations {
Activations(const ModelConfig& config, size_t batch_size,
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: weights_config(config),
layer_config(config.layer_configs[0]),
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
div_seq_len(static_cast<uint32_t>(seq_len)),
is_griffin(config.model == Model::GRIFFIN_2B),
query_scale(ChooseQueryScale(config)),
x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@ -63,10 +62,7 @@ struct Activations {
pre_att_rms_out("pre_att_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
att("att",
Extents2D(batch_size,
layer_config.heads * div_seq_len.GetDivisor()),
pad_),
att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
att_out(
"att_out",
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
@ -99,7 +95,7 @@ struct Activations {
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
1000000.0)),
gen_tokens(batch_size) {
query_scale(ChooseQueryScale(config)) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
@ -138,8 +134,6 @@ struct Activations {
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
gen_tokens.resize(batch_size);
}
bool IsGlobalLayer(size_t layer_idx) const {
@ -151,7 +145,6 @@ struct Activations {
const LayerConfig& layer_config;
hwy::Divisor div_seq_len;
bool is_griffin;
float query_scale;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;
@ -182,9 +175,7 @@ struct Activations {
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
std::vector<int> gen_tokens; // one per query in the batch
float query_scale;
};
} // namespace gcpp

View File

@ -153,15 +153,12 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
return pos - HWY_MIN(att_window_size - 1, pos);
}
void DotSoftmaxWeightedSum(const size_t num_tokens,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const size_t layer_idx,
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches,
Activations& activations, QBatch& qbatch,
NestedPools& pools) {
PROFILER_ZONE("Gen.Attention.DotSoftmax");
const hwy::Divisor div_queries(queries_pos.size());
const hwy::Divisor div_qbatch(qbatch.Size());
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
@ -176,7 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
// For each head/token/query, compute Q.K, softmax, and weighted V.
// Statically partition token/query across packages.
const size_t num_tq = num_tokens * div_queries.GetDivisor();
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
const IndexRangePartition tq_ranges =
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
ParallelizeOneRange(
@ -185,17 +182,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
pools.AllClusters(pkg_idx).Run(
tq_range.begin(), tq_range.end(),
[&](const size_t tq_idx, const size_t cluster_idx) {
const size_t query_idx = div_queries.Remainder(tq_idx);
const size_t batch_idx = div_queries.Divide(tq_idx);
auto& kv_cache = kv_caches[query_idx].kv_cache;
const size_t qi = div_qbatch.Remainder(tq_idx);
const size_t batch_idx = div_qbatch.Divide(tq_idx);
auto& kv_cache = qbatch.KV(qi).kv_cache;
// Find the token position in the query and calculate
// the range of cache positions to attend to.
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t start_pos =
StartPos(pos, activations.weights_config, layer_idx);
size_t last_pos = pos;
const size_t prefix_end = queries_prefix_end[query_idx];
const size_t prefix_end = qbatch.PrefixEnd(qi);
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
// last_pos in QDotK and WeightedSumV is inclusive.
last_pos = prefix_end - 1;
@ -235,14 +232,21 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
});
}
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations,
const QBatch& qbatch, const int flags,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
const hwy::Divisor div_queries(queries_pos.size());
const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
@ -260,13 +264,12 @@ static HWY_INLINE void ComputeQKV(
layer.qkv_einsum_w2.Rows()));
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_queries.Remainder(interleaved_idx);
const size_t batch_idx = div_queries.Divide(interleaved_idx);
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t cache_pos =
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
kv_caches[query_idx].kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size);
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
@ -280,11 +283,11 @@ static HWY_INLINE void ComputeQKV(
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % kv_heads;
const size_t interleaved_idx = task / kv_heads;
const size_t query_idx = div_queries.Remainder(interleaved_idx);
const size_t batch_idx = div_queries.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = kv_caches[query_idx].kv_cache;
auto& kv_cache = qbatch.KV(qi).kv_cache;
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
@ -320,35 +323,18 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
activations.att_sums);
}
// `queries_prefix_end` can be null (interpreted as all-zero) for standard
// causal attention, and must be non-null for prefix-LM style attention.
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos* queries_prefix_end,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches,
MatMulEnv& env, int flags) {
const size_t num_queries = queries_pos.size();
HWY_DASSERT(num_queries <= kv_caches.size());
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
QBatch& qbatch, MatMulEnv& env, int flags) {
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
(void)layer_config; // only used in HWY_DASSERT
std::vector<size_t> queries_prefix_end_vec;
QueriesPos queries_prefix_end_span;
if (queries_prefix_end == nullptr) {
queries_prefix_end_vec.assign(num_queries, 0);
queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
queries_prefix_end_vec.size());
queries_prefix_end = &queries_prefix_end_span;
}
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
flags, env);
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
layer, activations, kv_caches, env.ctx.pools);
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
env.ctx.pools);
SumHeads(layer, activations, env);
}

View File

@ -35,18 +35,14 @@ namespace gcpp {
const Activations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, \
const QueriesPos& queries_pos, \
const QueriesPos& queries_prefix_end, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
Activations& activations, \
const KVCaches& kv_caches, NestedPools& pools); \
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
const LayerWeightsPtrs& layer, \
Activations& activations, QBatch& qbatch, \
NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
const QueriesPos* queries_prefix_end, \
const size_t layer_idx, const LayerWeightsPtrs& layer, \
Activations& activations, const KVCaches& kv_caches, \
MatMulEnv& env, int flags); \
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
const LayerWeightsPtrs& layer, Activations& activations, \
QBatch& qbatch, MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -205,7 +205,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// RuntimeConfig runtime_config = { ... }; // This was already defined
double image_tokens_start = hwy::platform::Now();
// Pass the populated image object to GenerateImageTokens
model.GenerateImageTokens(runtime_config, image, image_tokens);
model.GenerateImageTokens(runtime_config,
active_conversation->kv_cache->SeqLen(), image,
image_tokens);
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
ss.str("");

View File

@ -45,7 +45,6 @@
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "io/blob_store.h"
#include "io/io.h" // Path
@ -62,14 +61,11 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
void Attention(LayerAttentionType type, size_t num_tokens,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
layer, activations, kv_caches, env,
GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@ -77,23 +73,23 @@ void Attention(LayerAttentionType type, size_t num_tokens,
// so map `layer` to the Griffin layer index.
const size_t griffin_layer =
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
&layer, kv_caches, env);
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
env);
}
}
static HWY_NOINLINE void TransformerLayer(
const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations,
QBatch& qbatch, MatMulEnv& env) {
const LayerConfig& layer_config = layer.layer_config;
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.pre_att_rms_out);
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
layer_idx, layer, activations, kv_caches, env);
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
qbatch, env);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.att_sums);
@ -134,7 +130,7 @@ static float EmbeddingScaling(size_t model_dim) {
// calling application.
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
@ -142,14 +138,14 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
x.Cols() * x.ElementBytes());
return image_token_position + 1;
}
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
x.Cols() * x.ElementBytes());
return image_token_position;
}
@ -169,34 +165,27 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
const auto embedding_span =
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
});
if (model_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
}
return image_token_position;
}
// Incremented in-place by Prefill* and DecodeStepT.
using QueriesMutablePos = hwy::Span<size_t>;
// Populates KV cache for batches of tokens from one query at a time. This is
// called if prompts are longer than the query batch size, and also in
// prefix-LM mode (end > 0), which must see all tokens in one batch.
static HWY_NOINLINE void PrefillTBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.PrefillT");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
HWY_DASSERT(num_queries == kv_caches.size());
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
@ -210,19 +199,16 @@ static HWY_NOINLINE void PrefillTBatch(
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
// Single query at a time, so pass slices of the spans because
// GemmaAttention will only access the first KV cache and position.
QueriesPos single_query_pos(&queries_pos[qi], 1);
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
KVCaches single_kv_cache(&kv_caches[qi], 1);
// One query at a time, batching will be the query's prompt tokens.
QBatch qbatch_1 = qbatch.Single(qi);
const size_t prompt_size = queries_prompt[qi].size();
const size_t prompt_size = qbatch_1.Prompt(0).size();
// In autoregressive mode, we don't need to prefill the last token, so - 1.
size_t prefill_this_query = prompt_size - 1;
const size_t prefix_end_this_query = queries_prefix_end[qi];
const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
// We can't attend beyond the prompt_size.
HWY_ASSERT(prefix_end_this_query <= prompt_size);
// Special case: if the prefix includes the last token, we need to prefill
@ -251,9 +237,9 @@ static HWY_NOINLINE void PrefillTBatch(
// Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0;
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
@ -262,18 +248,17 @@ static HWY_NOINLINE void PrefillTBatch(
// Transformer with one batch of tokens from a single query.
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
layer_idx, *weights.GetLayer(layer_idx), activations,
single_kv_cache, env);
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch_1, env);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const size_t pos = qbatch_1.Pos(0) + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
if (pos_in_prompt < prompt_size - 1) {
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
} else {
// The last token will be streamed later and we should only get here
// if we need to attend to the last token because it is in the prefix.
@ -281,7 +266,7 @@ static HWY_NOINLINE void PrefillTBatch(
}
}
queries_pos[qi] += tbatch_size;
qbatch_1.MutablePos(0) += tbatch_size;
} // for tbatch_start
if (attend_to_last_token) {
// We need to rewind the position for the last token that we only
@ -290,148 +275,125 @@ static HWY_NOINLINE void PrefillTBatch(
// decoding. Alternatives: (1) real masking; (2) always prefill the last
// token and only generate the next one from the already prefilled
// activations.
queries_pos[qi] -= 1;
qbatch_1.MutablePos(0) -= 1;
}
}
}
// Embeds token and calls each TransformerLayer. `queries_token` is the previous
// token from each query, and `queries_pos` are their position in the sequence.
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
// token-batched `PrefillTBatch`.
static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
const size_t num_queries = queries_token.size();
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
static HWY_NOINLINE void Transformer(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < num_queries; ++qi) {
const float token_f = queries_token[qi];
runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
1);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const float token_f = qbatch.PrevToken(qi);
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
"tokens", -1, &token_f, 1);
}
}
for (size_t qi = 0; qi < num_queries; ++qi) {
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
/*pos_in_prompt=*/0, config, weights, activations.x);
}
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
layer_idx, *weights.GetLayer(layer_idx), activations,
kv_caches, env);
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
activations, qbatch, env);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(queries_pos, layer_idx, activations);
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
activations);
}
}
}
// Populates KV cache for the batch queries, one token at a time. Only called
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
static HWY_NOINLINE void PrefillQBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t max_prompt_size, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
HWY_DASSERT(num_queries == activations.x.Rows());
HWY_DASSERT(num_queries == kv_caches.size());
hwy::BitSet4096<> prefill_active;
for (size_t qi = 0; qi < num_queries; ++qi) {
prefill_active.Set(qi);
HWY_DASSERT(queries_prefix_end[qi] == 0);
(void)queries_prefix_end;
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
}
non_eos = prefill_active;
// In autoregressive mode, we don't prefill the last token, hence - 1.
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
++pos_in_prompt) {
// Streams that have already finished prefill no longer interleave/stream.
for (size_t qi = 0; qi < num_queries; ++qi) {
if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
prefill_active.Clear(qi);
activations.gen_tokens[qi] = config.eos_id;
}
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
int token = config.eos_id;
if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
token = qbatch.Prompt(qi)[pos_in_prompt];
// Ignore StreamToken return value because requesting to stop does not
// make sense during prefill.
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
token, 0.0f);
}
// Batch := interleaved tokens, one from each non-EOS query.
prefill_active.Foreach([&](size_t qi) {
activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
});
qbatch.PrevToken(qi) = token;
}
// One token from each query in the batch. Increments queries_pos.
// The input (PrevToken) is one token from each query in the batch.
// Do not call DecodeStepT because it computes logits for token
// probabilities, which are not required for the prompt tokens.
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_pos, queries_prefix_end, config, runtime_config,
weights, activations, kv_caches, env);
prefill_active.Foreach([&](size_t qi) {
const int token = queries_prompt[qi][pos_in_prompt];
// Ignore any user request to stop during prefill.
(void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
token, 0.0f);
queries_pos[qi] += 1;
});
} // pos_in_prompt
Transformer(config, runtime_config, weights, activations, qbatch, env);
}
}
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
// and updates `non_eos` if the query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
const float prob, const ModelConfig& config,
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
// query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
Activations& activations,
hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi));
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
// User decided to stop: set next token to primary EOS.
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
qbatch.Pos(qi), token, prob))) {
// User decided to stop: set token to primary EOS to trigger IsEOS below.
token = config.eos_id;
HWY_DASSERT(config.IsEOS(token));
}
// Primary or secondary EOS: mark query as EOS.
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
qbatch.PrevToken(qi) = token;
qbatch.MutablePos(qi) += 1;
activations.gen_tokens[qi] = token;
// Primary or secondary EOS: mark query as EOS, but still increment (for
// multi-turn, we should still keep the prior EOS).
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
}
// For a batch of queries, runs Transformer, computes logits, samples and
// streams the token.
static void DecodeStepT(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_mutable_pos,
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
const SampleFunc& sample_token, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
static void DecodeStepT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == activations.x.Rows());
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_mutable_pos, queries_prefix_end, config, runtime_config,
weights, activations, kv_caches, env);
Transformer(config, runtime_config, weights, activations, qbatch, env);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(queries_mutable_pos, -1, activations);
runtime_config.activations_observer(
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
}
{
@ -447,10 +409,8 @@ static void DecodeStepT(
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
tp.prob, config, runtime_config, activations, non_eos);
if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
non_eos);
});
}
@ -477,46 +437,24 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
};
}
// Generates one continuation for each query in `queries_prompt`, which is one
// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
//
// `queries_pos` stores the KV cache position for each query. In the first turn
// of a chat, pos = 0; we increment each query's position after each token.
//
// `query_idx_start` is the query_idx of the first query in the batch, so that
// `StreamFunc` gets the global query index, not relative to the batch.
static void GenerateT(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
const ModelConfig& config, const RuntimeConfig& runtime_config,
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // non_eos uses `BitSet4096`.
HWY_ASSERT(num_queries == queries_pos_in.size());
HWY_ASSERT(num_queries == queries_prefix_end.size());
HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(num_queries == kv_caches.size());
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) {
if (queries_pos_in[i] == 0) {
kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
if (qbatch.MutablePos(qi) == 0) {
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
}
}
// Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend());
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t prefill_tokens = 0;
const size_t seq_len = kv_caches[0].SeqLen();
for (size_t qi = 0; qi < num_queries; ++qi) {
const PromptTokens& prompt = queries_prompt[qi];
size_t prefill_tokens = 0; // only for timing.
const size_t seq_len = qbatch.KV(0).SeqLen();
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const PromptTokens& prompt = qbatch.Prompt(qi);
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
// Prefill stops before size - 1 because the last prompt token is the
@ -526,43 +464,38 @@ static void GenerateT(
// Sanity check: prompts should not be empty, nor start with EOS.
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
// We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
}
HWY_ASSERT(prefill_tokens < seq_len);
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
hwy::BitSet4096<> non_eos;
hwy::BitSet4096<> non_eos; // indexed by qi
timing_info.prefill_start = hwy::platform::Now();
// Batch over the larger of prompt length, or queries.
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
activations.SetBatchSize(num_queries); // required before PrefillQBatch
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, max_prompt_size, config, runtime_config,
weights, activations, kv_caches, env, non_eos);
if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch
PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
qbatch, env, non_eos);
} else {
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, config, runtime_config, weights,
activations, kv_caches, env, non_eos);
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
non_eos);
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
}
HWY_DASSERT(num_queries == non_eos.Count());
HWY_DASSERT(non_eos.Count() == qbatch.Size());
timing_info.NotifyPrefill(prefill_tokens);
// queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < num_queries; ++qi) {
const size_t last_token_pos_in_prompt =
queries_mutable_pos[qi] - queries_pos_in[qi];
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
config, runtime_config, activations, non_eos);
// No incrementing queries_mutable_pos[qi].
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
runtime_config, qbatch, non_eos);
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
@ -577,10 +510,8 @@ static void GenerateT(
{
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, config, runtime_config, weights,
sample_token, activations, kv_caches, env, non_eos,
timing_info);
DecodeStepT(config, runtime_config, weights, sample_token, activations,
qbatch, env, non_eos, timing_info);
}
timing_info.NotifyGenerateDone();
}
@ -591,61 +522,38 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, KVCache& kv_cache,
MatMulEnv& env, TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.row_ptrs);
const size_t max_batch_size =
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
// TODO: move into Gemma?
Activations activations(config, max_batch_size, env.row_ptrs);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
config, runtime_config, weights, activations, kv_caches, env,
AllQueries all_queries(prompt, pos, prefix_end,
hwy::Span<KVCache>(&kv_cache, 1));
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
GenerateT(config, runtime_config, weights, activations, qbatch, env,
timing_info);
}
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
// queries, and calls `GenerateT` on each batch.
void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const ModelConfig& config,
void GenerateBatchT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
const ModelWeightsPtrs& weights, AllQueries& all_queries,
MatMulEnv& env, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size,
all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size, env.row_ptrs);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
// Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
config, runtime_config, weights, activations, qbatch_kv, env,
for (size_t start = 0; start < all_queries.NumQueries();
start += runtime_config.decode_qbatch_size) {
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
// Generate a batch of one token for each of `qbatch.Size()` queries.
GenerateT(config, runtime_config, weights, activations, qbatch, env,
timing_info);
}
}
void GenerateImageTokensT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const RuntimeConfig& runtime_config, size_t seq_len,
const ModelWeightsPtrs& weights, const Image& image,
ImageTokens& image_tokens, MatMulEnv& env) {
if (config.vit_config.layer_configs.empty()) {
@ -656,7 +564,8 @@ void GenerateImageTokensT(const ModelConfig& config,
const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size =
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
Activations prefill_activations(vit_config, num_tokens, num_tokens,
env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
@ -714,36 +623,25 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
}
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches,
AllQueries& all_queries,
TimingInfo& timing_info) const {
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
std::vector<size_t> prefix_end_vec;
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
prefix_end_vec.resize(queries_prompt.size(), 0);
queries_prefix_end_or_zeros =
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
}
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
runtime_config, weights_, kv_caches, env_, timing_info);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
weights_, all_queries, env_,
timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image,
size_t seq_len, const Image& image,
ImageTokens& image_tokens) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
model_.Config(), runtime_config, weights_, image, image_tokens, env_);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
seq_len, weights_, image,
image_tokens, env_);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

View File

@ -38,6 +38,129 @@
namespace gcpp {
struct PerQuery {
PromptTokens prompt;
// Position in the KV cache: initially zero for the first turn, or when
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
size_t mutable_pos;
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
// which might differ from `prompt.size() - 1` for prefix-LM.
size_t initial_pos;
// Zero for causal attention, or the end of the prefix for prefix-LM style
// attention in Paligemma.
size_t prefix_end;
KVCache& kv_cache;
// Previous token generated for this query, or the last prompt token. Will be
// fed into the next Transformer() call.
int prev_token = 0;
};
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
struct AllQueries {
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const hwy::Span<KVCache>& kv_caches) {
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompt,
.mutable_pos = pos,
.initial_pos = pos,
.prefix_end = prefix_end,
.kv_cache = kv_caches[i],
});
}
}
// Batch of queries with initial position set to zero. Causal attention
// is requested via empty or all-zero `prefix_end`.
AllQueries(
const hwy::Span<const PromptTokens>& prompts,
const hwy::Span<KVCache>& kv_caches,
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
HWY_ASSERT(prompts.size() == kv_caches.size());
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
per_query_.reserve(kv_caches.size());
for (size_t i = 0; i < kv_caches.size(); ++i) {
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
per_query_.push_back(PerQuery{
.prompt = prompts[i],
.mutable_pos = 0,
.initial_pos = 0,
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
.kv_cache = kv_caches[i],
});
}
}
size_t NumQueries() const { return per_query_.size(); }
PerQuery& operator[](size_t query_idx) {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
const PerQuery& operator[](size_t query_idx) const {
HWY_DASSERT(query_idx < NumQueries());
return per_query_[query_idx];
}
private:
std::vector<PerQuery> per_query_;
};
// View into AllQueries: either a batch of queries, or a single query for use
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
// reference to AllQueries.
class QBatch {
public:
QBatch(size_t start, size_t max_size, AllQueries& queries)
: start_(start),
max_size_(max_size),
queries_(queries),
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
HWY_DASSERT(size_ != 0);
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
}
// Returns a single-query view starting at `qi` relative to this batch.
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
size_t Size() const { return size_; }
// Returns index for use with `AllQueries` and `BatchStreamToken`.
size_t QueryIdx(size_t qi) const {
HWY_DASSERT(qi < size_);
return start_ + qi;
}
// Accessor functions to bridge the previous SoA and current AoS layout.
const PromptTokens& Prompt(size_t qi) const {
return queries_[QueryIdx(qi)].prompt;
}
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
size_t InitialPos(size_t qi) const {
return queries_[QueryIdx(qi)].initial_pos;
}
size_t PrefixEnd(size_t qi) const {
return queries_[QueryIdx(qi)].prefix_end;
}
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
private:
size_t start_;
size_t max_size_;
AllQueries& queries_;
size_t size_;
};
struct TimingInfo {
// be sure to populate prefill_start before calling NotifyPrefill.
void NotifyPrefill(size_t tokens) {
@ -100,8 +223,6 @@ struct TimingInfo {
// Returns the `MatMulEnv` after calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
using KVCaches = hwy::Span<KVCache>;
class Gemma {
public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
@ -133,24 +254,11 @@ class Gemma {
size_t pos, size_t prefix_end, KVCache& kv_cache,
TimingInfo& timing_info) const;
// `queries_pos` are the positions in the KV cache. Users are responsible for
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const KVCaches& kv_caches,
TimingInfo& timing_info) const {
GenerateBatch(runtime_config, queries_prompt, queries_pos,
/*queries_prefix_end=*/{}, kv_caches, timing_info);
}
// For prefix-LM style attention, we can pass the ends of the prefixes.
void GenerateBatch(const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, TimingInfo& timing_info) const;
AllQueries& all_queries, TimingInfo& timing_info) const;
// Generates the image tokens by running the image encoder ViT.
void GenerateImageTokens(const RuntimeConfig& runtime_config,
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
const Image& image, ImageTokens& image_tokens) const;
private:

View File

@ -82,8 +82,9 @@ using ImageTokens = MatStorageT<float>;
// true to continue generation.
using StreamFunc = std::function<bool(int, float)>;
// BatchStreamFunc is called with (query_idx, pos, token, probability).
// For prompt tokens, probability is 0.0f.
// StreamFunc should return false to stop generation and true to continue.
// For prompt tokens, probability is 0.0f. Generation continues if this returns
// true and stops if it returns false. Note that query_idx is absolute, not
// relative to the batch.
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
// If not empty, AcceptFunc is called with token. It should return false for
// tokens you don't want to generate and true for tokens you want to generate.
@ -112,8 +113,8 @@ using ActivationsObserverFunc =
// RuntimeConfig holds configuration for a single generation run.
// TODO: move into InferenceArgs, use that directly.
struct RuntimeConfig {
// If not empty, batch_stream_token is called for each token in the batch,
// instead of stream_token.
// If non-null, `batch_stream_token` is called for each token in the batch,
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
if (batch_stream_token) {
return batch_stream_token(query_idx, pos, token, prob);
@ -189,9 +190,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"developer/debug info).\n Default = 1.",
1); // Changed verbosity level to 1 since it's user-facing
visitor(seq_len, "seq_len", size_t{2048},
visitor(seq_len, "seq_len", size_t{8192},
"Sequence length, capped by ModelConfig.max_seq_len.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
"Maximum number of tokens to generate.");
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},

View File

@ -39,16 +39,10 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
size_t griffin_layer, Activations& activations,
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
const LayerWeightsPtrs* layer_weights,
const KVCaches& kv_caches, MatMulEnv& env) {
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
@ -64,9 +58,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
const size_t kHeadDim = model_dim / heads;
const size_t kMatrixSize = kHeadDim * kHeadDim;
const size_t num_queries = queries_pos.size();
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
const size_t num_interleaved = num_tokens * num_queries;
const size_t num_interleaved = num_tokens * qbatch.Size();
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
// X / Y linear layers.
// TODO: MatMul
@ -91,17 +84,17 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@ -127,16 +120,16 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
// RGLRU
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t qi = div_qbatch.Remainder(interleaved_idx);
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
const size_t pos = qbatch.Pos(qi) + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
float* HWY_RESTRICT rnn_state =
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;

View File

@ -28,10 +28,10 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
size_t griffin_layer, Activations& activations, \
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
const LayerWeightsPtrs* layer_weights, \
const KVCaches& kv_caches, MatMulEnv& env); \
Activations& activations, QBatch& qbatch, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -120,7 +120,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
.verbosity = inference.verbosity,
.use_spinning = threading.spin};
double image_tokens_start = hwy::platform::Now();
gemma.GenerateImageTokens(runtime_config, image, image_tokens);
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
image_tokens);
if (inference.verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,

View File

@ -57,7 +57,8 @@ class PaliGemmaTest : public ::testing::Test {
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
image, *image_tokens_);
}
std::string GemmaReply(const std::string& prompt_text) const {
@ -107,12 +108,11 @@ class PaliGemmaTest : public ::testing::Test {
TEST_F(PaliGemmaTest, QueryObjects) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const char* question = "answer en What objects are in the image?";
const char* expected_substring = "Building, Tower"; // 3B PT 224, 10B Mix 224
// 3B PT/Mix 224, 10B Mix 224
const char* expected_substring = "Building, Tower";
const Model model = s_env->GetGemma()->GetModelConfig().model;
if (model == Model::PALIGEMMA2_3B_448) {
expected_substring = "Lake.";
} else if (model == Model::PALIGEMMA2_3B_224) {
expected_substring = "Cloud, Water.";
} else if (model == Model::PALIGEMMA2_10B_224) {
expected_substring = "Building.";
}

View File

@ -190,7 +190,8 @@ class GemmaModel {
gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
c_image, *image_tokens_);
}
// Generates a response to the given prompt, using the last set image.