mirror of https://github.com/google/gemma.cpp.git
Major refactor: clarify query_idx (global) vs qi. Refs #607
Fix missing pos increment for last prefill and check that in gemma_test. Thanks to @ufownl for pointing this out. Change argument lists to QBatch with accessors. Increase default seq_len to 8k. PiperOrigin-RevId: 771937385
This commit is contained in:
parent
2c72ff2aa5
commit
e5c81f64a1
|
|
@ -109,16 +109,19 @@ void GemmaEnv::QueryModel(
|
|||
}
|
||||
|
||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt) {
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries != 0);
|
||||
std::vector<QueryResult> res(num_queries);
|
||||
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
|
||||
size_t query_index, size_t pos,
|
||||
int token, float) {
|
||||
const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
|
||||
const size_t pos,
|
||||
const int token, float) {
|
||||
HWY_ASSERT(query_index < num_queries);
|
||||
std::string token_text;
|
||||
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res[query_index].response.append(token_text);
|
||||
HWY_ASSERT(pos == res[query_index].tokens_generated);
|
||||
res[query_index].tokens_generated += 1;
|
||||
if (res[query_index].tokens_generated ==
|
||||
queries_prompt[query_index].size()) {
|
||||
|
|
@ -126,6 +129,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
}
|
||||
return true;
|
||||
};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
if (runtime_config_.verbosity >= 2) {
|
||||
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||
|
|
@ -137,13 +141,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
while (kv_caches_.size() < num_queries) {
|
||||
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
|
||||
}
|
||||
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
|
||||
|
||||
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
|
||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
std::vector<size_t> queries_pos(num_queries, 0);
|
||||
gemma_.GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -88,8 +88,10 @@ class GemmaEnv {
|
|||
// Runs inference on the given input and returns the top-1 result string and
|
||||
// the number of tokens that were generated.
|
||||
QueryResult QueryModel(const std::vector<int>& tokens);
|
||||
// The default prefix_end means "causal attention".
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
const QueriesPromptTokens& queries_prompt);
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||
QueryResult QueryModel(std::string& input);
|
||||
std::vector<QueryResult> BatchQueryModel(
|
||||
|
|
|
|||
|
|
@ -101,9 +101,11 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
const ModelConfig& config = model->GetModelConfig();
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
if (config.IsEOS(token)) return true;
|
||||
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
|
||||
HWY_ASSERT(query_idx == 0);
|
||||
HWY_ASSERT(pos == abs_pos);
|
||||
++abs_pos;
|
||||
if (config.IsEOS(token)) return true;
|
||||
std::string token_text;
|
||||
EXPECT_TRUE(
|
||||
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
|
|
@ -115,7 +117,7 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
.temperature = 0.0f,
|
||||
.gen = &s_env->MutableGen(),
|
||||
.verbosity = 2,
|
||||
.stream_token = stream_token,
|
||||
.batch_stream_token = stream_token,
|
||||
};
|
||||
TimingInfo timing_info{.verbosity = 0};
|
||||
// First "say" something slightly unusual.
|
||||
|
|
|
|||
|
|
@ -42,13 +42,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
|
|||
}
|
||||
|
||||
struct Activations {
|
||||
Activations(const ModelConfig& config, size_t batch_size,
|
||||
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: weights_config(config),
|
||||
layer_config(config.layer_configs[0]),
|
||||
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
|
||||
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||
query_scale(ChooseQueryScale(config)),
|
||||
|
||||
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
||||
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
||||
|
|
@ -63,10 +62,7 @@ struct Activations {
|
|||
|
||||
pre_att_rms_out("pre_att_rms_out",
|
||||
Extents2D(batch_size, config.model_dim), pad_),
|
||||
att("att",
|
||||
Extents2D(batch_size,
|
||||
layer_config.heads * div_seq_len.GetDivisor()),
|
||||
pad_),
|
||||
att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
|
||||
att_out(
|
||||
"att_out",
|
||||
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
|
||||
|
|
@ -99,7 +95,7 @@ struct Activations {
|
|||
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
|
||||
1000000.0)),
|
||||
|
||||
gen_tokens(batch_size) {
|
||||
query_scale(ChooseQueryScale(config)) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
// For MatMul outputs, precompute their row pointers.
|
||||
|
|
@ -138,8 +134,6 @@ struct Activations {
|
|||
griffin_gate_x.OverrideRows(batch_size);
|
||||
griffin_multiplier.OverrideRows(batch_size);
|
||||
}
|
||||
|
||||
gen_tokens.resize(batch_size);
|
||||
}
|
||||
|
||||
bool IsGlobalLayer(size_t layer_idx) const {
|
||||
|
|
@ -151,7 +145,6 @@ struct Activations {
|
|||
const LayerConfig& layer_config;
|
||||
hwy::Divisor div_seq_len;
|
||||
bool is_griffin;
|
||||
float query_scale;
|
||||
const Extents2D none_ = Extents2D();
|
||||
const MatPadding pad_ = MatPadding::kOdd;
|
||||
|
||||
|
|
@ -182,9 +175,7 @@ struct Activations {
|
|||
MatStorageT<float> inv_timescale;
|
||||
MatStorageT<float> inv_timescale_global;
|
||||
|
||||
// Storage for the last generated token from each query, passed to the next
|
||||
// Transformer() call.
|
||||
std::vector<int> gen_tokens; // one per query in the batch
|
||||
float query_scale;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -153,15 +153,12 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
|||
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||
}
|
||||
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const size_t layer_idx,
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
NestedPools& pools) {
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||
const hwy::Divisor div_queries(queries_pos.size());
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
|
|
@ -176,7 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
|||
// For each head/token/query, compute Q.K, softmax, and weighted V.
|
||||
|
||||
// Statically partition token/query across packages.
|
||||
const size_t num_tq = num_tokens * div_queries.GetDivisor();
|
||||
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
|
||||
const IndexRangePartition tq_ranges =
|
||||
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
|
||||
ParallelizeOneRange(
|
||||
|
|
@ -185,17 +182,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
|||
pools.AllClusters(pkg_idx).Run(
|
||||
tq_range.begin(), tq_range.end(),
|
||||
[&](const size_t tq_idx, const size_t cluster_idx) {
|
||||
const size_t query_idx = div_queries.Remainder(tq_idx);
|
||||
const size_t batch_idx = div_queries.Divide(tq_idx);
|
||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||
const size_t qi = div_qbatch.Remainder(tq_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
||||
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||
|
||||
// Find the token position in the query and calculate
|
||||
// the range of cache positions to attend to.
|
||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
const size_t start_pos =
|
||||
StartPos(pos, activations.weights_config, layer_idx);
|
||||
size_t last_pos = pos;
|
||||
const size_t prefix_end = queries_prefix_end[query_idx];
|
||||
const size_t prefix_end = qbatch.PrefixEnd(qi);
|
||||
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
|
||||
// last_pos in QDotK and WeightedSumV is inclusive.
|
||||
last_pos = prefix_end - 1;
|
||||
|
|
@ -235,14 +232,21 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
|||
});
|
||||
}
|
||||
|
||||
// Different functions use different naming conventions for the number of
|
||||
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
||||
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
||||
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
||||
|
||||
// Fills activations.q and writes to KV cache.
|
||||
static HWY_INLINE void ComputeQKV(
|
||||
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
|
||||
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
Activations& activations,
|
||||
const QBatch& qbatch, const int flags,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
const hwy::Divisor div_queries(queries_pos.size());
|
||||
const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
|
||||
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t kv_heads = layer_config.kv_heads;
|
||||
|
|
@ -260,13 +264,12 @@ static HWY_INLINE void ComputeQKV(
|
|||
layer.qkv_einsum_w2.Rows()));
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const size_t query_idx = div_queries.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_queries.Divide(interleaved_idx);
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t cache_pos =
|
||||
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
||||
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
|
||||
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
||||
kv_caches[query_idx].kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size);
|
||||
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
|
||||
}
|
||||
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
||||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||
|
|
@ -280,11 +283,11 @@ static HWY_INLINE void ComputeQKV(
|
|||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t head = task % kv_heads;
|
||||
const size_t interleaved_idx = task / kv_heads;
|
||||
const size_t query_idx = div_queries.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_queries.Divide(interleaved_idx);
|
||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size +
|
||||
head * qkv_dim * 2;
|
||||
|
|
@ -320,35 +323,18 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
|||
activations.att_sums);
|
||||
}
|
||||
|
||||
// `queries_prefix_end` can be null (interpreted as all-zero) for standard
|
||||
// causal attention, and must be non-null for prefix-LM style attention.
|
||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const QueriesPos* queries_prefix_end,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches,
|
||||
MatMulEnv& env, int flags) {
|
||||
const size_t num_queries = queries_pos.size();
|
||||
HWY_DASSERT(num_queries <= kv_caches.size());
|
||||
|
||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env, int flags) {
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
||||
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
||||
"query heads must be a multiple of key-value heads");
|
||||
(void)layer_config; // only used in HWY_DASSERT
|
||||
|
||||
std::vector<size_t> queries_prefix_end_vec;
|
||||
QueriesPos queries_prefix_end_span;
|
||||
if (queries_prefix_end == nullptr) {
|
||||
queries_prefix_end_vec.assign(num_queries, 0);
|
||||
queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
|
||||
queries_prefix_end_vec.size());
|
||||
queries_prefix_end = &queries_prefix_end_span;
|
||||
}
|
||||
|
||||
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
|
||||
flags, env);
|
||||
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
|
||||
layer, activations, kv_caches, env.ctx.pools);
|
||||
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||
env.ctx.pools);
|
||||
SumHeads(layer, activations, env);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,18 +35,14 @@ namespace gcpp {
|
|||
const Activations& activations, float* HWY_RESTRICT att, \
|
||||
float* HWY_RESTRICT att_out); \
|
||||
\
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, \
|
||||
const QueriesPos& queries_pos, \
|
||||
const QueriesPos& queries_prefix_end, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
Activations& activations, \
|
||||
const KVCaches& kv_caches, NestedPools& pools); \
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, \
|
||||
Activations& activations, QBatch& qbatch, \
|
||||
NestedPools& pools); \
|
||||
\
|
||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
|
||||
const QueriesPos* queries_prefix_end, \
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
Activations& activations, const KVCaches& kv_caches, \
|
||||
MatMulEnv& env, int flags); \
|
||||
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, Activations& activations, \
|
||||
QBatch& qbatch, MatMulEnv& env, int flags); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
|
|
@ -205,7 +205,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
|||
// RuntimeConfig runtime_config = { ... }; // This was already defined
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
// Pass the populated image object to GenerateImageTokens
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
model.GenerateImageTokens(runtime_config,
|
||||
active_conversation->kv_cache->SeqLen(), image,
|
||||
image_tokens);
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
|
||||
ss.str("");
|
||||
|
|
|
|||
430
gemma/gemma.cc
430
gemma/gemma.cc
|
|
@ -45,7 +45,6 @@
|
|||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/model_store.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "io/blob_store.h"
|
||||
#include "io/io.h" // Path
|
||||
|
|
@ -62,14 +61,11 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
void Attention(LayerAttentionType type, size_t num_tokens,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
|
||||
layer, activations, kv_caches, env,
|
||||
GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
|
||||
/*flags=*/0);
|
||||
} else {
|
||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||
|
|
@ -77,23 +73,23 @@ void Attention(LayerAttentionType type, size_t num_tokens,
|
|||
// so map `layer` to the Griffin layer index.
|
||||
const size_t griffin_layer =
|
||||
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
|
||||
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
|
||||
&layer, kv_caches, env);
|
||||
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
|
||||
env);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void TransformerLayer(
|
||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
|
||||
const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env) {
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
||||
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
||||
activations.pre_att_rms_out);
|
||||
|
||||
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
|
||||
layer_idx, layer, activations, kv_caches, env);
|
||||
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
|
||||
qbatch, env);
|
||||
|
||||
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
||||
activations.att_sums);
|
||||
|
|
@ -134,7 +130,7 @@ static float EmbeddingScaling(size_t model_dim) {
|
|||
// calling application.
|
||||
// Returns new image_token_position.
|
||||
static HWY_NOINLINE size_t
|
||||
EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
||||
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
|
||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
|
|
@ -142,14 +138,14 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
|||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||
image_tokens != nullptr && token == -2 &&
|
||||
image_token_position < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
|
||||
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
|
||||
x.Cols() * x.ElementBytes());
|
||||
return image_token_position + 1;
|
||||
}
|
||||
|
||||
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
|
||||
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
|
||||
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
|
||||
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
|
||||
x.Cols() * x.ElementBytes());
|
||||
return image_token_position;
|
||||
}
|
||||
|
|
@ -169,34 +165,27 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
|||
const auto embedding_span =
|
||||
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
|
||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
||||
model_dim);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
|
||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
|
||||
});
|
||||
|
||||
if (model_config.absolute_pe) {
|
||||
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
|
||||
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
|
||||
}
|
||||
return image_token_position;
|
||||
}
|
||||
|
||||
// Incremented in-place by Prefill* and DecodeStepT.
|
||||
using QueriesMutablePos = hwy::Span<size_t>;
|
||||
|
||||
// Populates KV cache for batches of tokens from one query at a time. This is
|
||||
// called if prompts are longer than the query batch size, and also in
|
||||
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
||||
static HWY_NOINLINE void PrefillTBatch(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
||||
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.PrefillT");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||
HWY_DASSERT(num_queries == kv_caches.size());
|
||||
|
||||
// Batches are important for amortizing loading weights over multiple tokens.
|
||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||
|
|
@ -210,19 +199,16 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
||||
|
||||
// For each query. `qi` is within the batch, not the global query index.
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
non_eos.Set(qi);
|
||||
|
||||
// Single query at a time, so pass slices of the spans because
|
||||
// GemmaAttention will only access the first KV cache and position.
|
||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
|
||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||
// One query at a time, batching will be the query's prompt tokens.
|
||||
QBatch qbatch_1 = qbatch.Single(qi);
|
||||
|
||||
const size_t prompt_size = queries_prompt[qi].size();
|
||||
const size_t prompt_size = qbatch_1.Prompt(0).size();
|
||||
// In autoregressive mode, we don't need to prefill the last token, so - 1.
|
||||
size_t prefill_this_query = prompt_size - 1;
|
||||
const size_t prefix_end_this_query = queries_prefix_end[qi];
|
||||
const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
|
||||
// We can't attend beyond the prompt_size.
|
||||
HWY_ASSERT(prefix_end_this_query <= prompt_size);
|
||||
// Special case: if the prefix includes the last token, we need to prefill
|
||||
|
|
@ -251,9 +237,9 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
// Fill activations.x (much faster than TransformerLayer).
|
||||
size_t image_token_position = 0;
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||
image_token_position = EmbedMMToken(
|
||||
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
||||
runtime_config.image_tokens, image_token_position);
|
||||
|
|
@ -262,18 +248,17 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
// Transformer with one batch of tokens from a single query.
|
||||
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
||||
++layer_idx) {
|
||||
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
|
||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||
single_kv_cache, env);
|
||||
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch_1, env);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||
const size_t pos = queries_pos[qi] + ti;
|
||||
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||
const size_t pos_in_prompt = tbatch_start + ti;
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||
if (pos_in_prompt < prompt_size - 1) {
|
||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
||||
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
|
||||
} else {
|
||||
// The last token will be streamed later and we should only get here
|
||||
// if we need to attend to the last token because it is in the prefix.
|
||||
|
|
@ -281,7 +266,7 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
}
|
||||
}
|
||||
|
||||
queries_pos[qi] += tbatch_size;
|
||||
qbatch_1.MutablePos(0) += tbatch_size;
|
||||
} // for tbatch_start
|
||||
if (attend_to_last_token) {
|
||||
// We need to rewind the position for the last token that we only
|
||||
|
|
@ -290,148 +275,125 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
// decoding. Alternatives: (1) real masking; (2) always prefill the last
|
||||
// token and only generate the next one from the already prefilled
|
||||
// activations.
|
||||
queries_pos[qi] -= 1;
|
||||
qbatch_1.MutablePos(0) -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Embeds token and calls each TransformerLayer. `queries_token` is the previous
|
||||
// token from each query, and `queries_pos` are their position in the sequence.
|
||||
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
|
||||
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
|
||||
// token-batched `PrefillTBatch`.
|
||||
static HWY_NOINLINE void Transformer(
|
||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||
|
||||
static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
const float token_f = queries_token[qi];
|
||||
runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
|
||||
1);
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const float token_f = qbatch.PrevToken(qi);
|
||||
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
|
||||
"tokens", -1, &token_f, 1);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||
}
|
||||
|
||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
|
||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||
kv_caches, env);
|
||||
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, qbatch, env);
|
||||
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||
activations);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Populates KV cache for the batch queries, one token at a time. Only called
|
||||
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
|
||||
static HWY_NOINLINE void PrefillQBatch(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
const size_t max_prompt_size, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||
HWY_DASSERT(num_queries == kv_caches.size());
|
||||
|
||||
hwy::BitSet4096<> prefill_active;
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
prefill_active.Set(qi);
|
||||
|
||||
HWY_DASSERT(queries_prefix_end[qi] == 0);
|
||||
(void)queries_prefix_end;
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
non_eos.Set(qi);
|
||||
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
|
||||
}
|
||||
non_eos = prefill_active;
|
||||
|
||||
// In autoregressive mode, we don't prefill the last token, hence - 1.
|
||||
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
|
||||
++pos_in_prompt) {
|
||||
// Streams that have already finished prefill no longer interleave/stream.
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
|
||||
prefill_active.Clear(qi);
|
||||
activations.gen_tokens[qi] = config.eos_id;
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
int token = config.eos_id;
|
||||
if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
|
||||
token = qbatch.Prompt(qi)[pos_in_prompt];
|
||||
// Ignore StreamToken return value because requesting to stop does not
|
||||
// make sense during prefill.
|
||||
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
|
||||
token, 0.0f);
|
||||
}
|
||||
|
||||
qbatch.PrevToken(qi) = token;
|
||||
}
|
||||
|
||||
// Batch := interleaved tokens, one from each non-EOS query.
|
||||
prefill_active.Foreach([&](size_t qi) {
|
||||
activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
|
||||
});
|
||||
|
||||
// One token from each query in the batch. Increments queries_pos.
|
||||
// The input (PrevToken) is one token from each query in the batch.
|
||||
// Do not call DecodeStepT because it computes logits for token
|
||||
// probabilities, which are not required for the prompt tokens.
|
||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||
queries_pos, queries_prefix_end, config, runtime_config,
|
||||
weights, activations, kv_caches, env);
|
||||
|
||||
prefill_active.Foreach([&](size_t qi) {
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
// Ignore any user request to stop during prefill.
|
||||
(void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
|
||||
token, 0.0f);
|
||||
queries_pos[qi] += 1;
|
||||
});
|
||||
} // pos_in_prompt
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
}
|
||||
}
|
||||
|
||||
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
||||
// and updates `non_eos` if the query is at the end of its sequence.
|
||||
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||
const float prob, const ModelConfig& config,
|
||||
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
||||
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
|
||||
// query is at the end of its sequence.
|
||||
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
Activations& activations,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
HWY_DASSERT(non_eos.Get(qi));
|
||||
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
|
||||
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
||||
|
||||
// User decided to stop: set next token to primary EOS.
|
||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
|
||||
qbatch.Pos(qi), token, prob))) {
|
||||
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
||||
token = config.eos_id;
|
||||
HWY_DASSERT(config.IsEOS(token));
|
||||
}
|
||||
|
||||
// Primary or secondary EOS: mark query as EOS.
|
||||
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
||||
qbatch.PrevToken(qi) = token;
|
||||
qbatch.MutablePos(qi) += 1;
|
||||
|
||||
activations.gen_tokens[qi] = token;
|
||||
// Primary or secondary EOS: mark query as EOS, but still increment (for
|
||||
// multi-turn, we should still keep the prior EOS).
|
||||
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
||||
}
|
||||
|
||||
// For a batch of queries, runs Transformer, computes logits, samples and
|
||||
// streams the token.
|
||||
static void DecodeStepT(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_mutable_pos,
|
||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
const SampleFunc& sample_token, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||
static void DecodeStepT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const SampleFunc& sample_token,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
TimingInfo& timing_info) {
|
||||
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||
|
||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, queries_prefix_end, config, runtime_config,
|
||||
weights, activations, kv_caches, env);
|
||||
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(queries_mutable_pos, -1, activations);
|
||||
runtime_config.activations_observer(
|
||||
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -447,10 +409,8 @@ static void DecodeStepT(
|
|||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||
timing_info.NotifyGenerated();
|
||||
|
||||
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
|
||||
tp.prob, config, runtime_config, activations, non_eos);
|
||||
|
||||
if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
|
||||
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
|
||||
non_eos);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -477,46 +437,24 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
|||
};
|
||||
}
|
||||
|
||||
// Generates one continuation for each query in `queries_prompt`, which is one
|
||||
// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
|
||||
//
|
||||
// `queries_pos` stores the KV cache position for each query. In the first turn
|
||||
// of a chat, pos = 0; we increment each query's position after each token.
|
||||
//
|
||||
// `query_idx_start` is the query_idx of the first query in the batch, so that
|
||||
// `StreamFunc` gets the global query index, not relative to the batch.
|
||||
static void GenerateT(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(num_queries <= 4096); // non_eos uses `BitSet4096`.
|
||||
HWY_ASSERT(num_queries == queries_pos_in.size());
|
||||
HWY_ASSERT(num_queries == queries_prefix_end.size());
|
||||
HWY_ASSERT(num_queries <= activations.x.Rows());
|
||||
HWY_ASSERT(num_queries == kv_caches.size());
|
||||
|
||||
// Decode: generates one continuation token for each query in `qbatch`.
|
||||
static void GenerateT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
|
||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
if (queries_pos_in[i] == 0) {
|
||||
kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
if (qbatch.MutablePos(qi) == 0) {
|
||||
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||
}
|
||||
}
|
||||
|
||||
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||
queries_pos_in.cend());
|
||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
queries_pos_copy.size());
|
||||
|
||||
size_t max_prompt_size = 0;
|
||||
bool all_prefix_end_are_zero = true;
|
||||
size_t prefill_tokens = 0;
|
||||
const size_t seq_len = kv_caches[0].SeqLen();
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
const PromptTokens& prompt = queries_prompt[qi];
|
||||
size_t prefill_tokens = 0; // only for timing.
|
||||
const size_t seq_len = qbatch.KV(0).SeqLen();
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const PromptTokens& prompt = qbatch.Prompt(qi);
|
||||
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||
|
||||
// Prefill stops before size - 1 because the last prompt token is the
|
||||
|
|
@ -526,43 +464,38 @@ static void GenerateT(
|
|||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||
|
||||
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
||||
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
|
||||
|
||||
// We use a single divisor, so all sequence lengths must be the same.
|
||||
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
|
||||
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
|
||||
}
|
||||
HWY_ASSERT(prefill_tokens < seq_len);
|
||||
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
|
||||
|
||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||
// qi loops anyway.
|
||||
hwy::BitSet4096<> non_eos;
|
||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||
|
||||
timing_info.prefill_start = hwy::platform::Now();
|
||||
// Batch over the larger of prompt length, or queries.
|
||||
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
||||
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
||||
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, max_prompt_size, config, runtime_config,
|
||||
weights, activations, kv_caches, env, non_eos);
|
||||
if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
|
||||
activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch
|
||||
PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
|
||||
qbatch, env, non_eos);
|
||||
} else {
|
||||
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, config, runtime_config, weights,
|
||||
activations, kv_caches, env, non_eos);
|
||||
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
||||
PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
|
||||
non_eos);
|
||||
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
|
||||
}
|
||||
HWY_DASSERT(num_queries == non_eos.Count());
|
||||
HWY_DASSERT(non_eos.Count() == qbatch.Size());
|
||||
timing_info.NotifyPrefill(prefill_tokens);
|
||||
// queries_pos have been incremented by Prefill.
|
||||
|
||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
const size_t last_token_pos_in_prompt =
|
||||
queries_mutable_pos[qi] - queries_pos_in[qi];
|
||||
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
|
||||
queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
|
||||
config, runtime_config, activations, non_eos);
|
||||
// No incrementing queries_mutable_pos[qi].
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
|
||||
runtime_config, qbatch, non_eos);
|
||||
}
|
||||
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
|
|
@ -577,10 +510,8 @@ static void GenerateT(
|
|||
{
|
||||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, config, runtime_config, weights,
|
||||
sample_token, activations, kv_caches, env, non_eos,
|
||||
timing_info);
|
||||
DecodeStepT(config, runtime_config, weights, sample_token, activations,
|
||||
qbatch, env, non_eos, timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
}
|
||||
|
|
@ -591,61 +522,38 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
|||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, KVCache& kv_cache,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
constexpr size_t kNumQueries = 1;
|
||||
const size_t qbatch_start = 0;
|
||||
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||
kv_cache.SeqLen(), env.row_ptrs);
|
||||
|
||||
const size_t max_batch_size =
|
||||
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
|
||||
// TODO: move into Gemma?
|
||||
Activations activations(config, max_batch_size, env.row_ptrs);
|
||||
|
||||
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
|
||||
QueriesPos queries_pos(&pos, kNumQueries);
|
||||
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||
|
||||
GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
|
||||
config, runtime_config, weights, activations, kv_caches, env,
|
||||
AllQueries all_queries(prompt, pos, prefix_end,
|
||||
hwy::Span<KVCache>(&kv_cache, 1));
|
||||
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
|
||||
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
}
|
||||
|
||||
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
|
||||
// queries, and calls `GenerateT` on each batch.
|
||||
void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const ModelConfig& config,
|
||||
void GenerateBatchT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
|
||||
const ModelWeightsPtrs& weights, AllQueries& all_queries,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||
HWY_ASSERT(kv_caches.size() >= num_queries);
|
||||
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
|
||||
runtime_config.prefill_tbatch_size);
|
||||
Activations activations(config, max_batch_size,
|
||||
all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
|
||||
|
||||
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
||||
const size_t max_batch_size =
|
||||
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
||||
Activations activations(config, max_batch_size, env.row_ptrs);
|
||||
|
||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||
qbatch_start += max_qbatch_size) {
|
||||
// Generate one batch of tokens from `qbatch_size` queries.
|
||||
const size_t qbatch_size =
|
||||
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
||||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
||||
qbatch_size);
|
||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
||||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
||||
qbatch_size);
|
||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||
GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
|
||||
config, runtime_config, weights, activations, qbatch_kv, env,
|
||||
for (size_t start = 0; start < all_queries.NumQueries();
|
||||
start += runtime_config.decode_qbatch_size) {
|
||||
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||
timing_info);
|
||||
}
|
||||
}
|
||||
|
||||
void GenerateImageTokensT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const ModelWeightsPtrs& weights, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||
if (config.vit_config.layer_configs.empty()) {
|
||||
|
|
@ -656,7 +564,8 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
|
||||
Activations prefill_activations(vit_config, num_tokens, num_tokens,
|
||||
env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
|
|
@ -714,36 +623,25 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
|||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches,
|
||||
AllQueries& all_queries,
|
||||
TimingInfo& timing_info) const {
|
||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
||||
QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
|
||||
std::vector<size_t> prefix_end_vec;
|
||||
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
|
||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
||||
queries_prefix_end_or_zeros =
|
||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
|
||||
queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
|
||||
runtime_config, weights_, kv_caches, env_, timing_info);
|
||||
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
|
||||
weights_, all_queries, env_,
|
||||
timing_info);
|
||||
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image,
|
||||
size_t seq_len, const Image& image,
|
||||
ImageTokens& image_tokens) const {
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
|
||||
model_.Config(), runtime_config, weights_, image, image_tokens, env_);
|
||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
|
||||
seq_len, weights_, image,
|
||||
image_tokens, env_);
|
||||
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
|
|
|||
142
gemma/gemma.h
142
gemma/gemma.h
|
|
@ -38,6 +38,129 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
struct PerQuery {
|
||||
PromptTokens prompt;
|
||||
|
||||
// Position in the KV cache: initially zero for the first turn, or when
|
||||
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
|
||||
size_t mutable_pos;
|
||||
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
|
||||
// which might differ from `prompt.size() - 1` for prefix-LM.
|
||||
size_t initial_pos;
|
||||
// Zero for causal attention, or the end of the prefix for prefix-LM style
|
||||
// attention in Paligemma.
|
||||
size_t prefix_end;
|
||||
|
||||
KVCache& kv_cache;
|
||||
|
||||
// Previous token generated for this query, or the last prompt token. Will be
|
||||
// fed into the next Transformer() call.
|
||||
int prev_token = 0;
|
||||
};
|
||||
|
||||
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
|
||||
struct AllQueries {
|
||||
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
|
||||
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const hwy::Span<KVCache>& kv_caches) {
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompt,
|
||||
.mutable_pos = pos,
|
||||
.initial_pos = pos,
|
||||
.prefix_end = prefix_end,
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Batch of queries with initial position set to zero. Causal attention
|
||||
// is requested via empty or all-zero `prefix_end`.
|
||||
AllQueries(
|
||||
const hwy::Span<const PromptTokens>& prompts,
|
||||
const hwy::Span<KVCache>& kv_caches,
|
||||
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||
per_query_.reserve(kv_caches.size());
|
||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||
per_query_.push_back(PerQuery{
|
||||
.prompt = prompts[i],
|
||||
.mutable_pos = 0,
|
||||
.initial_pos = 0,
|
||||
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||
.kv_cache = kv_caches[i],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
size_t NumQueries() const { return per_query_.size(); }
|
||||
|
||||
PerQuery& operator[](size_t query_idx) {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
const PerQuery& operator[](size_t query_idx) const {
|
||||
HWY_DASSERT(query_idx < NumQueries());
|
||||
return per_query_[query_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<PerQuery> per_query_;
|
||||
};
|
||||
|
||||
// View into AllQueries: either a batch of queries, or a single query for use
|
||||
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
|
||||
// reference to AllQueries.
|
||||
class QBatch {
|
||||
public:
|
||||
QBatch(size_t start, size_t max_size, AllQueries& queries)
|
||||
: start_(start),
|
||||
max_size_(max_size),
|
||||
queries_(queries),
|
||||
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
|
||||
HWY_DASSERT(size_ != 0);
|
||||
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||
}
|
||||
|
||||
// Returns a single-query view starting at `qi` relative to this batch.
|
||||
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
||||
|
||||
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||
size_t Size() const { return size_; }
|
||||
|
||||
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||
size_t QueryIdx(size_t qi) const {
|
||||
HWY_DASSERT(qi < size_);
|
||||
return start_ + qi;
|
||||
}
|
||||
|
||||
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||
const PromptTokens& Prompt(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prompt;
|
||||
}
|
||||
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||
size_t InitialPos(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].initial_pos;
|
||||
}
|
||||
size_t PrefixEnd(size_t qi) const {
|
||||
return queries_[QueryIdx(qi)].prefix_end;
|
||||
}
|
||||
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||
|
||||
private:
|
||||
size_t start_;
|
||||
size_t max_size_;
|
||||
AllQueries& queries_;
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||
void NotifyPrefill(size_t tokens) {
|
||||
|
|
@ -100,8 +223,6 @@ struct TimingInfo {
|
|||
// Returns the `MatMulEnv` after calling `SetArgs`.
|
||||
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
|
||||
|
||||
using KVCaches = hwy::Span<KVCache>;
|
||||
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
|
|
@ -133,24 +254,11 @@ class Gemma {
|
|||
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
||||
TimingInfo& timing_info) const;
|
||||
|
||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
||||
TimingInfo& timing_info) const {
|
||||
GenerateBatch(runtime_config, queries_prompt, queries_pos,
|
||||
/*queries_prefix_end=*/{}, kv_caches, timing_info);
|
||||
}
|
||||
// For prefix-LM style attention, we can pass the ends of the prefixes.
|
||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const KVCaches& kv_caches, TimingInfo& timing_info) const;
|
||||
AllQueries& all_queries, TimingInfo& timing_info) const;
|
||||
|
||||
// Generates the image tokens by running the image encoder ViT.
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const Image& image, ImageTokens& image_tokens) const;
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -82,8 +82,9 @@ using ImageTokens = MatStorageT<float>;
|
|||
// true to continue generation.
|
||||
using StreamFunc = std::function<bool(int, float)>;
|
||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||
// For prompt tokens, probability is 0.0f.
|
||||
// StreamFunc should return false to stop generation and true to continue.
|
||||
// For prompt tokens, probability is 0.0f. Generation continues if this returns
|
||||
// true and stops if it returns false. Note that query_idx is absolute, not
|
||||
// relative to the batch.
|
||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||
// If not empty, AcceptFunc is called with token. It should return false for
|
||||
// tokens you don't want to generate and true for tokens you want to generate.
|
||||
|
|
@ -112,8 +113,8 @@ using ActivationsObserverFunc =
|
|||
// RuntimeConfig holds configuration for a single generation run.
|
||||
// TODO: move into InferenceArgs, use that directly.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
// instead of stream_token.
|
||||
// If non-null, `batch_stream_token` is called for each token in the batch,
|
||||
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||
if (batch_stream_token) {
|
||||
return batch_stream_token(query_idx, pos, token, prob);
|
||||
|
|
@ -189,9 +190,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"developer/debug info).\n Default = 1.",
|
||||
1); // Changed verbosity level to 1 since it's user-facing
|
||||
|
||||
visitor(seq_len, "seq_len", size_t{2048},
|
||||
visitor(seq_len, "seq_len", size_t{8192},
|
||||
"Sequence length, capped by ModelConfig.max_seq_len.");
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
|
||||
|
|
|
|||
|
|
@ -39,16 +39,10 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Different functions use different naming conventions for the number of
|
||||
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
||||
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
||||
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
||||
|
||||
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
||||
size_t griffin_layer, Activations& activations,
|
||||
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
|
||||
const LayerWeightsPtrs* layer_weights,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
|
@ -64,9 +58,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
|||
const size_t kHeadDim = model_dim / heads;
|
||||
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||
|
||||
const size_t num_queries = queries_pos.size();
|
||||
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
|
||||
const size_t num_interleaved = num_tokens * num_queries;
|
||||
const size_t num_interleaved = num_tokens * qbatch.Size();
|
||||
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
|
||||
|
||||
// X / Y linear layers.
|
||||
// TODO: MatMul
|
||||
|
|
@ -91,17 +84,17 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
|||
// Conv1D.
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
|
||||
|
||||
// cache[i] = input at time t-i.
|
||||
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
||||
cache[0] = x;
|
||||
for (size_t i = 1; i < conv_1d_width; i++) {
|
||||
cache[i] =
|
||||
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
|
||||
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
|
||||
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
||||
}
|
||||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||
|
|
@ -127,16 +120,16 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
|||
// RGLRU
|
||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||
++interleaved_idx) {
|
||||
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||
|
||||
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
||||
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
|
||||
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
|
||||
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
|
||||
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
|
||||
float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
|
||||
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
|
||||
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
|
||||
float* HWY_RESTRICT rnn_state =
|
||||
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
|
||||
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
|
||||
|
||||
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
size_t head_offset = head * kHeadDim;
|
||||
|
|
|
|||
|
|
@ -26,13 +26,13 @@
|
|||
namespace gcpp {
|
||||
|
||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
|
||||
size_t griffin_layer, Activations& activations, \
|
||||
const LayerWeightsPtrs* layer_weights, \
|
||||
const KVCaches& kv_caches, MatMulEnv& env); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
|
||||
const LayerWeightsPtrs* layer_weights, \
|
||||
Activations& activations, QBatch& qbatch, \
|
||||
MatMulEnv& env); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
// Function declarations for each SIMD target. Allows direct call from the
|
||||
|
|
|
|||
|
|
@ -120,7 +120,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
|||
.verbosity = inference.verbosity,
|
||||
.use_spinning = threading.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
gemma.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
|
||||
image_tokens);
|
||||
if (inference.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
|
|
|
|||
|
|
@ -57,7 +57,8 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
||||
gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
|
||||
image, *image_tokens_);
|
||||
}
|
||||
|
||||
std::string GemmaReply(const std::string& prompt_text) const {
|
||||
|
|
@ -107,12 +108,11 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
TEST_F(PaliGemmaTest, QueryObjects) {
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
const char* question = "answer en What objects are in the image?";
|
||||
const char* expected_substring = "Building, Tower"; // 3B PT 224, 10B Mix 224
|
||||
// 3B PT/Mix 224, 10B Mix 224
|
||||
const char* expected_substring = "Building, Tower";
|
||||
const Model model = s_env->GetGemma()->GetModelConfig().model;
|
||||
if (model == Model::PALIGEMMA2_3B_448) {
|
||||
expected_substring = "Lake.";
|
||||
} else if (model == Model::PALIGEMMA2_3B_224) {
|
||||
expected_substring = "Cloud, Water.";
|
||||
} else if (model == Model::PALIGEMMA2_10B_224) {
|
||||
expected_substring = "Building.";
|
||||
}
|
||||
|
|
|
|||
|
|
@ -190,7 +190,8 @@ class GemmaModel {
|
|||
gcpp::MatPadding::kOdd));
|
||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||
.verbosity = 0};
|
||||
gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
|
||||
gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
|
||||
c_image, *image_tokens_);
|
||||
}
|
||||
|
||||
// Generates a response to the given prompt, using the last set image.
|
||||
|
|
|
|||
Loading…
Reference in New Issue