mirror of https://github.com/google/gemma.cpp.git
Major refactor: clarify query_idx (global) vs qi. Refs #607
Fix missing pos increment for last prefill and check that in gemma_test. Thanks to @ufownl for pointing this out. Change argument lists to QBatch with accessors. Increase default seq_len to 8k. PiperOrigin-RevId: 771937385
This commit is contained in:
parent
2c72ff2aa5
commit
e5c81f64a1
|
|
@ -109,16 +109,19 @@ void GemmaEnv::QueryModel(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
const QueriesPromptTokens& queries_prompt) {
|
const QueriesPromptTokens& queries_prompt,
|
||||||
|
const hwy::Span<const size_t>& prefix_end) {
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_ASSERT(num_queries != 0);
|
HWY_ASSERT(num_queries != 0);
|
||||||
std::vector<QueryResult> res(num_queries);
|
std::vector<QueryResult> res(num_queries);
|
||||||
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
|
const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
|
||||||
size_t query_index, size_t pos,
|
const size_t pos,
|
||||||
int token, float) {
|
const int token, float) {
|
||||||
|
HWY_ASSERT(query_index < num_queries);
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||||
res[query_index].response.append(token_text);
|
res[query_index].response.append(token_text);
|
||||||
|
HWY_ASSERT(pos == res[query_index].tokens_generated);
|
||||||
res[query_index].tokens_generated += 1;
|
res[query_index].tokens_generated += 1;
|
||||||
if (res[query_index].tokens_generated ==
|
if (res[query_index].tokens_generated ==
|
||||||
queries_prompt[query_index].size()) {
|
queries_prompt[query_index].size()) {
|
||||||
|
|
@ -126,6 +129,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
runtime_config_.batch_stream_token = batch_stream_token;
|
||||||
if (runtime_config_.verbosity >= 2) {
|
if (runtime_config_.verbosity >= 2) {
|
||||||
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
|
||||||
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
runtime_config_.max_generated_tokens, runtime_config_.temperature,
|
||||||
|
|
@ -137,13 +141,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
while (kv_caches_.size() < num_queries) {
|
while (kv_caches_.size() < num_queries) {
|
||||||
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
|
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
|
||||||
}
|
}
|
||||||
|
const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
|
||||||
|
|
||||||
|
gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
|
||||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||||
runtime_config_.batch_stream_token = batch_stream_token;
|
gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
|
||||||
std::vector<size_t> queries_pos(num_queries, 0);
|
|
||||||
gemma_.GenerateBatch(runtime_config_, queries_prompt,
|
|
||||||
QueriesPos(queries_pos.data(), num_queries),
|
|
||||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,8 +88,10 @@ class GemmaEnv {
|
||||||
// Runs inference on the given input and returns the top-1 result string and
|
// Runs inference on the given input and returns the top-1 result string and
|
||||||
// the number of tokens that were generated.
|
// the number of tokens that were generated.
|
||||||
QueryResult QueryModel(const std::vector<int>& tokens);
|
QueryResult QueryModel(const std::vector<int>& tokens);
|
||||||
|
// The default prefix_end means "causal attention".
|
||||||
std::vector<QueryResult> BatchQueryModel(
|
std::vector<QueryResult> BatchQueryModel(
|
||||||
const QueriesPromptTokens& queries_prompt);
|
const QueriesPromptTokens& queries_prompt,
|
||||||
|
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
|
||||||
// Adds turn structure to input, tokenizes and calls the above overload.
|
// Adds turn structure to input, tokenizes and calls the above overload.
|
||||||
QueryResult QueryModel(std::string& input);
|
QueryResult QueryModel(std::string& input);
|
||||||
std::vector<QueryResult> BatchQueryModel(
|
std::vector<QueryResult> BatchQueryModel(
|
||||||
|
|
|
||||||
|
|
@ -101,9 +101,11 @@ TEST_F(GemmaTest, Multiturn) {
|
||||||
const ModelConfig& config = model->GetModelConfig();
|
const ModelConfig& config = model->GetModelConfig();
|
||||||
size_t abs_pos = 0;
|
size_t abs_pos = 0;
|
||||||
std::string response;
|
std::string response;
|
||||||
auto stream_token = [&](int token, float) {
|
auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
|
||||||
if (config.IsEOS(token)) return true;
|
HWY_ASSERT(query_idx == 0);
|
||||||
|
HWY_ASSERT(pos == abs_pos);
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
|
if (config.IsEOS(token)) return true;
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||||
|
|
@ -115,7 +117,7 @@ TEST_F(GemmaTest, Multiturn) {
|
||||||
.temperature = 0.0f,
|
.temperature = 0.0f,
|
||||||
.gen = &s_env->MutableGen(),
|
.gen = &s_env->MutableGen(),
|
||||||
.verbosity = 2,
|
.verbosity = 2,
|
||||||
.stream_token = stream_token,
|
.batch_stream_token = stream_token,
|
||||||
};
|
};
|
||||||
TimingInfo timing_info{.verbosity = 0};
|
TimingInfo timing_info{.verbosity = 0};
|
||||||
// First "say" something slightly unusual.
|
// First "say" something slightly unusual.
|
||||||
|
|
|
||||||
|
|
@ -42,13 +42,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Activations {
|
struct Activations {
|
||||||
Activations(const ModelConfig& config, size_t batch_size,
|
Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
|
||||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||||
: weights_config(config),
|
: weights_config(config),
|
||||||
layer_config(config.layer_configs[0]),
|
layer_config(config.layer_configs[0]),
|
||||||
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
|
div_seq_len(static_cast<uint32_t>(seq_len)),
|
||||||
is_griffin(config.model == Model::GRIFFIN_2B),
|
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||||
query_scale(ChooseQueryScale(config)),
|
|
||||||
|
|
||||||
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
||||||
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
||||||
|
|
@ -63,10 +62,7 @@ struct Activations {
|
||||||
|
|
||||||
pre_att_rms_out("pre_att_rms_out",
|
pre_att_rms_out("pre_att_rms_out",
|
||||||
Extents2D(batch_size, config.model_dim), pad_),
|
Extents2D(batch_size, config.model_dim), pad_),
|
||||||
att("att",
|
att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
|
||||||
Extents2D(batch_size,
|
|
||||||
layer_config.heads * div_seq_len.GetDivisor()),
|
|
||||||
pad_),
|
|
||||||
att_out(
|
att_out(
|
||||||
"att_out",
|
"att_out",
|
||||||
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
|
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
|
||||||
|
|
@ -99,7 +95,7 @@ struct Activations {
|
||||||
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
|
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
|
||||||
1000000.0)),
|
1000000.0)),
|
||||||
|
|
||||||
gen_tokens(batch_size) {
|
query_scale(ChooseQueryScale(config)) {
|
||||||
HWY_ASSERT(batch_size != 0);
|
HWY_ASSERT(batch_size != 0);
|
||||||
|
|
||||||
// For MatMul outputs, precompute their row pointers.
|
// For MatMul outputs, precompute their row pointers.
|
||||||
|
|
@ -138,8 +134,6 @@ struct Activations {
|
||||||
griffin_gate_x.OverrideRows(batch_size);
|
griffin_gate_x.OverrideRows(batch_size);
|
||||||
griffin_multiplier.OverrideRows(batch_size);
|
griffin_multiplier.OverrideRows(batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
gen_tokens.resize(batch_size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsGlobalLayer(size_t layer_idx) const {
|
bool IsGlobalLayer(size_t layer_idx) const {
|
||||||
|
|
@ -151,7 +145,6 @@ struct Activations {
|
||||||
const LayerConfig& layer_config;
|
const LayerConfig& layer_config;
|
||||||
hwy::Divisor div_seq_len;
|
hwy::Divisor div_seq_len;
|
||||||
bool is_griffin;
|
bool is_griffin;
|
||||||
float query_scale;
|
|
||||||
const Extents2D none_ = Extents2D();
|
const Extents2D none_ = Extents2D();
|
||||||
const MatPadding pad_ = MatPadding::kOdd;
|
const MatPadding pad_ = MatPadding::kOdd;
|
||||||
|
|
||||||
|
|
@ -182,9 +175,7 @@ struct Activations {
|
||||||
MatStorageT<float> inv_timescale;
|
MatStorageT<float> inv_timescale;
|
||||||
MatStorageT<float> inv_timescale_global;
|
MatStorageT<float> inv_timescale_global;
|
||||||
|
|
||||||
// Storage for the last generated token from each query, passed to the next
|
float query_scale;
|
||||||
// Transformer() call.
|
|
||||||
std::vector<int> gen_tokens; // one per query in the batch
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -153,15 +153,12 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
||||||
return pos - HWY_MIN(att_window_size - 1, pos);
|
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotSoftmaxWeightedSum(const size_t num_tokens,
|
void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
|
||||||
const QueriesPos& queries_pos,
|
|
||||||
const QueriesPos& queries_prefix_end,
|
|
||||||
const size_t layer_idx,
|
|
||||||
const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer,
|
||||||
Activations& activations, const KVCaches& kv_caches,
|
Activations& activations, QBatch& qbatch,
|
||||||
NestedPools& pools) {
|
NestedPools& pools) {
|
||||||
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||||
const hwy::Divisor div_queries(queries_pos.size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
const size_t qkv_dim = layer_config.qkv_dim;
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
|
|
||||||
|
|
@ -176,7 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||||
// For each head/token/query, compute Q.K, softmax, and weighted V.
|
// For each head/token/query, compute Q.K, softmax, and weighted V.
|
||||||
|
|
||||||
// Statically partition token/query across packages.
|
// Statically partition token/query across packages.
|
||||||
const size_t num_tq = num_tokens * div_queries.GetDivisor();
|
const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
|
||||||
const IndexRangePartition tq_ranges =
|
const IndexRangePartition tq_ranges =
|
||||||
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
|
StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
|
||||||
ParallelizeOneRange(
|
ParallelizeOneRange(
|
||||||
|
|
@ -185,17 +182,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||||
pools.AllClusters(pkg_idx).Run(
|
pools.AllClusters(pkg_idx).Run(
|
||||||
tq_range.begin(), tq_range.end(),
|
tq_range.begin(), tq_range.end(),
|
||||||
[&](const size_t tq_idx, const size_t cluster_idx) {
|
[&](const size_t tq_idx, const size_t cluster_idx) {
|
||||||
const size_t query_idx = div_queries.Remainder(tq_idx);
|
const size_t qi = div_qbatch.Remainder(tq_idx);
|
||||||
const size_t batch_idx = div_queries.Divide(tq_idx);
|
const size_t batch_idx = div_qbatch.Divide(tq_idx);
|
||||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||||
|
|
||||||
// Find the token position in the query and calculate
|
// Find the token position in the query and calculate
|
||||||
// the range of cache positions to attend to.
|
// the range of cache positions to attend to.
|
||||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
const size_t start_pos =
|
const size_t start_pos =
|
||||||
StartPos(pos, activations.weights_config, layer_idx);
|
StartPos(pos, activations.weights_config, layer_idx);
|
||||||
size_t last_pos = pos;
|
size_t last_pos = pos;
|
||||||
const size_t prefix_end = queries_prefix_end[query_idx];
|
const size_t prefix_end = qbatch.PrefixEnd(qi);
|
||||||
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
|
if (prefix_end > 0 && prefix_end - 1 > last_pos) {
|
||||||
// last_pos in QDotK and WeightedSumV is inclusive.
|
// last_pos in QDotK and WeightedSumV is inclusive.
|
||||||
last_pos = prefix_end - 1;
|
last_pos = prefix_end - 1;
|
||||||
|
|
@ -235,14 +232,21 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Different functions use different naming conventions for the number of
|
||||||
|
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
||||||
|
// count `num_interleaved`. Functions that are query-dependent, such as
|
||||||
|
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
||||||
|
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
||||||
|
|
||||||
// Fills activations.q and writes to KV cache.
|
// Fills activations.q and writes to KV cache.
|
||||||
static HWY_INLINE void ComputeQKV(
|
static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
|
||||||
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
|
const LayerWeightsPtrs& layer,
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
Activations& activations,
|
||||||
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
|
const QBatch& qbatch, const int flags,
|
||||||
|
MatMulEnv& env) {
|
||||||
PROFILER_ZONE("Gen.Attention.QKV");
|
PROFILER_ZONE("Gen.Attention.QKV");
|
||||||
const hwy::Divisor div_queries(queries_pos.size());
|
const hwy::Divisor div_qbatch(qbatch.Size());
|
||||||
const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
|
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
const size_t qkv_dim = layer_config.qkv_dim;
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
const size_t kv_heads = layer_config.kv_heads;
|
const size_t kv_heads = layer_config.kv_heads;
|
||||||
|
|
@ -260,13 +264,12 @@ static HWY_INLINE void ComputeQKV(
|
||||||
layer.qkv_einsum_w2.Rows()));
|
layer.qkv_einsum_w2.Rows()));
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
++interleaved_idx) {
|
++interleaved_idx) {
|
||||||
const size_t query_idx = div_queries.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_queries.Divide(interleaved_idx);
|
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
const size_t cache_pos =
|
const size_t cache_pos =
|
||||||
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
|
||||||
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
||||||
kv_caches[query_idx].kv_cache.Row(cache_pos) +
|
qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
|
||||||
layer_idx * cache_layer_size);
|
|
||||||
}
|
}
|
||||||
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
||||||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||||
|
|
@ -280,11 +283,11 @@ static HWY_INLINE void ComputeQKV(
|
||||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
const size_t head = task % kv_heads;
|
const size_t head = task % kv_heads;
|
||||||
const size_t interleaved_idx = task / kv_heads;
|
const size_t interleaved_idx = task / kv_heads;
|
||||||
const size_t query_idx = div_queries.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_queries.Divide(interleaved_idx);
|
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
||||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
auto& kv_cache = qbatch.KV(qi).kv_cache;
|
||||||
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||||
layer_idx * cache_layer_size +
|
layer_idx * cache_layer_size +
|
||||||
head * qkv_dim * 2;
|
head * qkv_dim * 2;
|
||||||
|
|
@ -320,35 +323,18 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||||
activations.att_sums);
|
activations.att_sums);
|
||||||
}
|
}
|
||||||
|
|
||||||
// `queries_prefix_end` can be null (interpreted as all-zero) for standard
|
void GemmaAttention(size_t num_tokens, const size_t layer_idx,
|
||||||
// causal attention, and must be non-null for prefix-LM style attention.
|
const LayerWeightsPtrs& layer, Activations& activations,
|
||||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
QBatch& qbatch, MatMulEnv& env, int flags) {
|
||||||
const QueriesPos* queries_prefix_end,
|
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
|
||||||
Activations& activations, const KVCaches& kv_caches,
|
|
||||||
MatMulEnv& env, int flags) {
|
|
||||||
const size_t num_queries = queries_pos.size();
|
|
||||||
HWY_DASSERT(num_queries <= kv_caches.size());
|
|
||||||
|
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
HWY_DASSERT(!layer_config.IsMHA()); // No longer supported.
|
||||||
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
|
||||||
"query heads must be a multiple of key-value heads");
|
"query heads must be a multiple of key-value heads");
|
||||||
(void)layer_config; // only used in HWY_DASSERT
|
(void)layer_config; // only used in HWY_DASSERT
|
||||||
|
|
||||||
std::vector<size_t> queries_prefix_end_vec;
|
ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
|
||||||
QueriesPos queries_prefix_end_span;
|
DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
|
||||||
if (queries_prefix_end == nullptr) {
|
env.ctx.pools);
|
||||||
queries_prefix_end_vec.assign(num_queries, 0);
|
|
||||||
queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
|
|
||||||
queries_prefix_end_vec.size());
|
|
||||||
queries_prefix_end = &queries_prefix_end_span;
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
|
|
||||||
flags, env);
|
|
||||||
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
|
|
||||||
layer, activations, kv_caches, env.ctx.pools);
|
|
||||||
SumHeads(layer, activations, env);
|
SumHeads(layer, activations, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,18 +35,14 @@ namespace gcpp {
|
||||||
const Activations& activations, float* HWY_RESTRICT att, \
|
const Activations& activations, float* HWY_RESTRICT att, \
|
||||||
float* HWY_RESTRICT att_out); \
|
float* HWY_RESTRICT att_out); \
|
||||||
\
|
\
|
||||||
void DotSoftmaxWeightedSum(const size_t num_tokens, \
|
void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \
|
||||||
const QueriesPos& queries_pos, \
|
const LayerWeightsPtrs& layer, \
|
||||||
const QueriesPos& queries_prefix_end, \
|
Activations& activations, QBatch& qbatch, \
|
||||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
NestedPools& pools); \
|
||||||
Activations& activations, \
|
|
||||||
const KVCaches& kv_caches, NestedPools& pools); \
|
|
||||||
\
|
\
|
||||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
|
void GemmaAttention(size_t num_tokens, const size_t layer_idx, \
|
||||||
const QueriesPos* queries_prefix_end, \
|
const LayerWeightsPtrs& layer, Activations& activations, \
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer, \
|
QBatch& qbatch, MatMulEnv& env, int flags); \
|
||||||
Activations& activations, const KVCaches& kv_caches, \
|
|
||||||
MatMulEnv& env, int flags); \
|
|
||||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||||
} // namespace NAMESPACE
|
} // namespace NAMESPACE
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -205,7 +205,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
|
||||||
// RuntimeConfig runtime_config = { ... }; // This was already defined
|
// RuntimeConfig runtime_config = { ... }; // This was already defined
|
||||||
double image_tokens_start = hwy::platform::Now();
|
double image_tokens_start = hwy::platform::Now();
|
||||||
// Pass the populated image object to GenerateImageTokens
|
// Pass the populated image object to GenerateImageTokens
|
||||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
model.GenerateImageTokens(runtime_config,
|
||||||
|
active_conversation->kv_cache->SeqLen(), image,
|
||||||
|
image_tokens);
|
||||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||||
|
|
||||||
ss.str("");
|
ss.str("");
|
||||||
|
|
|
||||||
430
gemma/gemma.cc
430
gemma/gemma.cc
|
|
@ -45,7 +45,6 @@
|
||||||
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/model_store.h"
|
#include "gemma/model_store.h"
|
||||||
#include "gemma/tokenizer.h"
|
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "io/blob_store.h"
|
#include "io/blob_store.h"
|
||||||
#include "io/io.h" // Path
|
#include "io/io.h" // Path
|
||||||
|
|
@ -62,14 +61,11 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
void Attention(LayerAttentionType type, size_t num_tokens,
|
void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||||
const QueriesPos& queries_pos,
|
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
|
||||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
|
GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
|
||||||
layer, activations, kv_caches, env,
|
|
||||||
/*flags=*/0);
|
/*flags=*/0);
|
||||||
} else {
|
} else {
|
||||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
|
|
@ -77,23 +73,23 @@ void Attention(LayerAttentionType type, size_t num_tokens,
|
||||||
// so map `layer` to the Griffin layer index.
|
// so map `layer` to the Griffin layer index.
|
||||||
const size_t griffin_layer =
|
const size_t griffin_layer =
|
||||||
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
|
activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
|
||||||
GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
|
GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
|
||||||
&layer, kv_caches, env);
|
env);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE void TransformerLayer(
|
static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
|
||||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
const size_t layer_idx,
|
||||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
const LayerWeightsPtrs& layer,
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
Activations& activations,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
QBatch& qbatch, MatMulEnv& env) {
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
|
|
||||||
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
||||||
activations.pre_att_rms_out);
|
activations.pre_att_rms_out);
|
||||||
|
|
||||||
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
|
Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
|
||||||
layer_idx, layer, activations, kv_caches, env);
|
qbatch, env);
|
||||||
|
|
||||||
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
||||||
activations.att_sums);
|
activations.att_sums);
|
||||||
|
|
@ -134,7 +130,7 @@ static float EmbeddingScaling(size_t model_dim) {
|
||||||
// calling application.
|
// calling application.
|
||||||
// Returns new image_token_position.
|
// Returns new image_token_position.
|
||||||
static HWY_NOINLINE size_t
|
static HWY_NOINLINE size_t
|
||||||
EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||||
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
|
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
|
||||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
||||||
size_t image_token_position = 0) {
|
size_t image_token_position = 0) {
|
||||||
|
|
@ -142,14 +138,14 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
||||||
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||||
image_tokens != nullptr && token == -2 &&
|
image_tokens != nullptr && token == -2 &&
|
||||||
image_token_position < image_tokens->Rows()) {
|
image_token_position < image_tokens->Rows()) {
|
||||||
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
|
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
|
||||||
x.Cols() * x.ElementBytes());
|
x.Cols() * x.ElementBytes());
|
||||||
return image_token_position + 1;
|
return image_token_position + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
|
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
|
||||||
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
|
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
|
||||||
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
|
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
|
||||||
x.Cols() * x.ElementBytes());
|
x.Cols() * x.ElementBytes());
|
||||||
return image_token_position;
|
return image_token_position;
|
||||||
}
|
}
|
||||||
|
|
@ -169,34 +165,27 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
||||||
const auto embedding_span =
|
const auto embedding_span =
|
||||||
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
|
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
|
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
|
||||||
model_dim);
|
model_dim);
|
||||||
MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
|
MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (model_config.absolute_pe) {
|
if (model_config.absolute_pe) {
|
||||||
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
|
AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
|
||||||
}
|
}
|
||||||
return image_token_position;
|
return image_token_position;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incremented in-place by Prefill* and DecodeStepT.
|
|
||||||
using QueriesMutablePos = hwy::Span<size_t>;
|
|
||||||
|
|
||||||
// Populates KV cache for batches of tokens from one query at a time. This is
|
// Populates KV cache for batches of tokens from one query at a time. This is
|
||||||
// called if prompts are longer than the query batch size, and also in
|
// called if prompts are longer than the query batch size, and also in
|
||||||
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
||||||
static HWY_NOINLINE void PrefillTBatch(
|
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const RuntimeConfig& runtime_config,
|
||||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
const ModelWeightsPtrs& weights,
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
Activations& activations, QBatch& qbatch,
|
||||||
const ModelWeightsPtrs& weights, Activations& activations,
|
MatMulEnv& env,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
hwy::BitSet4096<>& non_eos) {
|
||||||
PROFILER_ZONE("Gen.PrefillT");
|
PROFILER_ZONE("Gen.PrefillT");
|
||||||
const size_t num_queries = queries_prompt.size();
|
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
|
||||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
|
||||||
HWY_DASSERT(num_queries == kv_caches.size());
|
|
||||||
|
|
||||||
// Batches are important for amortizing loading weights over multiple tokens.
|
// Batches are important for amortizing loading weights over multiple tokens.
|
||||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||||
|
|
@ -210,19 +199,16 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
||||||
|
|
||||||
// For each query. `qi` is within the batch, not the global query index.
|
// For each query. `qi` is within the batch, not the global query index.
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
non_eos.Set(qi);
|
non_eos.Set(qi);
|
||||||
|
|
||||||
// Single query at a time, so pass slices of the spans because
|
// One query at a time, batching will be the query's prompt tokens.
|
||||||
// GemmaAttention will only access the first KV cache and position.
|
QBatch qbatch_1 = qbatch.Single(qi);
|
||||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
|
||||||
QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
|
|
||||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
|
||||||
|
|
||||||
const size_t prompt_size = queries_prompt[qi].size();
|
const size_t prompt_size = qbatch_1.Prompt(0).size();
|
||||||
// In autoregressive mode, we don't need to prefill the last token, so - 1.
|
// In autoregressive mode, we don't need to prefill the last token, so - 1.
|
||||||
size_t prefill_this_query = prompt_size - 1;
|
size_t prefill_this_query = prompt_size - 1;
|
||||||
const size_t prefix_end_this_query = queries_prefix_end[qi];
|
const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
|
||||||
// We can't attend beyond the prompt_size.
|
// We can't attend beyond the prompt_size.
|
||||||
HWY_ASSERT(prefix_end_this_query <= prompt_size);
|
HWY_ASSERT(prefix_end_this_query <= prompt_size);
|
||||||
// Special case: if the prefix includes the last token, we need to prefill
|
// Special case: if the prefix includes the last token, we need to prefill
|
||||||
|
|
@ -251,9 +237,9 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
// Fill activations.x (much faster than TransformerLayer).
|
// Fill activations.x (much faster than TransformerLayer).
|
||||||
size_t image_token_position = 0;
|
size_t image_token_position = 0;
|
||||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||||
const size_t pos = queries_pos[qi] + ti;
|
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||||
const size_t pos_in_prompt = tbatch_start + ti;
|
const size_t pos_in_prompt = tbatch_start + ti;
|
||||||
const int token = queries_prompt[qi][pos_in_prompt];
|
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||||
image_token_position = EmbedMMToken(
|
image_token_position = EmbedMMToken(
|
||||||
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
token, ti, pos, pos_in_prompt, config, weights, activations.x,
|
||||||
runtime_config.image_tokens, image_token_position);
|
runtime_config.image_tokens, image_token_position);
|
||||||
|
|
@ -262,18 +248,17 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
// Transformer with one batch of tokens from a single query.
|
// Transformer with one batch of tokens from a single query.
|
||||||
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
||||||
++layer_idx) {
|
++layer_idx) {
|
||||||
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
|
TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
|
||||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
activations, qbatch_1, env);
|
||||||
single_kv_cache, env);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||||
const size_t pos = queries_pos[qi] + ti;
|
const size_t pos = qbatch_1.Pos(0) + ti;
|
||||||
const size_t pos_in_prompt = tbatch_start + ti;
|
const size_t pos_in_prompt = tbatch_start + ti;
|
||||||
const int token = queries_prompt[qi][pos_in_prompt];
|
const int token = qbatch_1.Prompt(0)[pos_in_prompt];
|
||||||
if (pos_in_prompt < prompt_size - 1) {
|
if (pos_in_prompt < prompt_size - 1) {
|
||||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
|
||||||
} else {
|
} else {
|
||||||
// The last token will be streamed later and we should only get here
|
// The last token will be streamed later and we should only get here
|
||||||
// if we need to attend to the last token because it is in the prefix.
|
// if we need to attend to the last token because it is in the prefix.
|
||||||
|
|
@ -281,7 +266,7 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
queries_pos[qi] += tbatch_size;
|
qbatch_1.MutablePos(0) += tbatch_size;
|
||||||
} // for tbatch_start
|
} // for tbatch_start
|
||||||
if (attend_to_last_token) {
|
if (attend_to_last_token) {
|
||||||
// We need to rewind the position for the last token that we only
|
// We need to rewind the position for the last token that we only
|
||||||
|
|
@ -290,148 +275,125 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
// decoding. Alternatives: (1) real masking; (2) always prefill the last
|
// decoding. Alternatives: (1) real masking; (2) always prefill the last
|
||||||
// token and only generate the next one from the already prefilled
|
// token and only generate the next one from the already prefilled
|
||||||
// activations.
|
// activations.
|
||||||
queries_pos[qi] -= 1;
|
qbatch_1.MutablePos(0) -= 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Embeds token and calls each TransformerLayer. `queries_token` is the previous
|
// Embeds PrevToken (one from each query) and calls each TransformerLayer.
|
||||||
// token from each query, and `queries_pos` are their position in the sequence.
|
|
||||||
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
|
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
|
||||||
// token-batched `PrefillTBatch`.
|
// token-batched `PrefillTBatch`.
|
||||||
static HWY_NOINLINE void Transformer(
|
static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
const RuntimeConfig& runtime_config,
|
||||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
const ModelWeightsPtrs& weights,
|
||||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
Activations& activations, QBatch& qbatch,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
MatMulEnv& env) {
|
||||||
const size_t num_queries = queries_token.size();
|
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
|
||||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
|
||||||
|
|
||||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
const float token_f = queries_token[qi];
|
const float token_f = qbatch.PrevToken(qi);
|
||||||
runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
|
runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
|
||||||
1);
|
"tokens", -1, &token_f, 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
|
||||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
/*pos_in_prompt=*/0, config, weights, activations.x);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||||
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
|
TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
|
||||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
activations, qbatch, env);
|
||||||
kv_caches, env);
|
|
||||||
|
|
||||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
runtime_config.activations_observer(
|
||||||
|
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
|
||||||
|
activations);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Populates KV cache for the batch queries, one token at a time. Only called
|
// Populates KV cache for the batch queries, one token at a time. Only called
|
||||||
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
|
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
|
||||||
static HWY_NOINLINE void PrefillQBatch(
|
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const ModelConfig& config,
|
||||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
const RuntimeConfig& runtime_config,
|
||||||
const size_t max_prompt_size, const ModelConfig& config,
|
const ModelWeightsPtrs& weights,
|
||||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
Activations& activations, QBatch& qbatch,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
MatMulEnv& env,
|
||||||
hwy::BitSet4096<>& non_eos) {
|
hwy::BitSet4096<>& non_eos) {
|
||||||
PROFILER_ZONE("Gen.Prefill");
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
const size_t num_queries = queries_prompt.size();
|
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
|
||||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
|
||||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
|
||||||
HWY_DASSERT(num_queries == kv_caches.size());
|
|
||||||
|
|
||||||
hwy::BitSet4096<> prefill_active;
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
non_eos.Set(qi);
|
||||||
prefill_active.Set(qi);
|
HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
|
||||||
|
|
||||||
HWY_DASSERT(queries_prefix_end[qi] == 0);
|
|
||||||
(void)queries_prefix_end;
|
|
||||||
}
|
}
|
||||||
non_eos = prefill_active;
|
|
||||||
|
|
||||||
// In autoregressive mode, we don't prefill the last token, hence - 1.
|
// In autoregressive mode, we don't prefill the last token, hence - 1.
|
||||||
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
|
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
|
||||||
++pos_in_prompt) {
|
++pos_in_prompt) {
|
||||||
// Streams that have already finished prefill no longer interleave/stream.
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
int token = config.eos_id;
|
||||||
if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
|
if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
|
||||||
prefill_active.Clear(qi);
|
token = qbatch.Prompt(qi)[pos_in_prompt];
|
||||||
activations.gen_tokens[qi] = config.eos_id;
|
// Ignore StreamToken return value because requesting to stop does not
|
||||||
|
// make sense during prefill.
|
||||||
|
(void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
|
||||||
|
token, 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
qbatch.PrevToken(qi) = token;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Batch := interleaved tokens, one from each non-EOS query.
|
// The input (PrevToken) is one token from each query in the batch.
|
||||||
prefill_active.Foreach([&](size_t qi) {
|
|
||||||
activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
|
|
||||||
});
|
|
||||||
|
|
||||||
// One token from each query in the batch. Increments queries_pos.
|
|
||||||
// Do not call DecodeStepT because it computes logits for token
|
// Do not call DecodeStepT because it computes logits for token
|
||||||
// probabilities, which are not required for the prompt tokens.
|
// probabilities, which are not required for the prompt tokens.
|
||||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
queries_pos, queries_prefix_end, config, runtime_config,
|
}
|
||||||
weights, activations, kv_caches, env);
|
|
||||||
|
|
||||||
prefill_active.Foreach([&](size_t qi) {
|
|
||||||
const int token = queries_prompt[qi][pos_in_prompt];
|
|
||||||
// Ignore any user request to stop during prefill.
|
|
||||||
(void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
|
|
||||||
token, 0.0f);
|
|
||||||
queries_pos[qi] += 1;
|
|
||||||
});
|
|
||||||
} // pos_in_prompt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
|
||||||
// and updates `non_eos` if the query is at the end of its sequence.
|
// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
|
||||||
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
// query is at the end of its sequence.
|
||||||
const float prob, const ModelConfig& config,
|
static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
||||||
|
const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
Activations& activations,
|
QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
|
||||||
hwy::BitSet4096<>& non_eos) {
|
HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called.
|
||||||
HWY_DASSERT(non_eos.Get(qi));
|
|
||||||
|
|
||||||
// User decided to stop: set next token to primary EOS.
|
if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
|
||||||
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
qbatch.Pos(qi), token, prob))) {
|
||||||
|
// User decided to stop: set token to primary EOS to trigger IsEOS below.
|
||||||
token = config.eos_id;
|
token = config.eos_id;
|
||||||
HWY_DASSERT(config.IsEOS(token));
|
HWY_DASSERT(config.IsEOS(token));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Primary or secondary EOS: mark query as EOS.
|
qbatch.PrevToken(qi) = token;
|
||||||
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
qbatch.MutablePos(qi) += 1;
|
||||||
|
|
||||||
activations.gen_tokens[qi] = token;
|
// Primary or secondary EOS: mark query as EOS, but still increment (for
|
||||||
|
// multi-turn, we should still keep the prior EOS).
|
||||||
|
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For a batch of queries, runs Transformer, computes logits, samples and
|
// For a batch of queries, runs Transformer, computes logits, samples and
|
||||||
// streams the token.
|
// streams the token.
|
||||||
static void DecodeStepT(
|
static void DecodeStepT(const ModelConfig& config,
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const RuntimeConfig& runtime_config,
|
||||||
const QueriesMutablePos& queries_mutable_pos,
|
const ModelWeightsPtrs& weights,
|
||||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
const SampleFunc& sample_token,
|
||||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
Activations& activations, QBatch& qbatch,
|
||||||
const SampleFunc& sample_token, Activations& activations,
|
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
TimingInfo& timing_info) {
|
||||||
TimingInfo& timing_info) {
|
HWY_DASSERT(qbatch.Size() == activations.x.Rows());
|
||||||
const size_t num_queries = queries_prompt.size();
|
|
||||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
|
||||||
|
|
||||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
Transformer(config, runtime_config, weights, activations, qbatch, env);
|
||||||
queries_mutable_pos, queries_prefix_end, config, runtime_config,
|
|
||||||
weights, activations, kv_caches, env);
|
|
||||||
|
|
||||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||||
|
|
||||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
runtime_config.activations_observer(queries_mutable_pos, -1, activations);
|
runtime_config.activations_observer(
|
||||||
|
QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
@ -447,10 +409,8 @@ static void DecodeStepT(
|
||||||
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||||
timing_info.NotifyGenerated();
|
timing_info.NotifyGenerated();
|
||||||
|
|
||||||
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
|
StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
|
||||||
tp.prob, config, runtime_config, activations, non_eos);
|
non_eos);
|
||||||
|
|
||||||
if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -477,46 +437,24 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates one continuation for each query in `queries_prompt`, which is one
|
// Decode: generates one continuation token for each query in `qbatch`.
|
||||||
// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
|
static void GenerateT(const ModelConfig& config,
|
||||||
//
|
const RuntimeConfig& runtime_config,
|
||||||
// `queries_pos` stores the KV cache position for each query. In the first turn
|
const ModelWeightsPtrs& weights, Activations& activations,
|
||||||
// of a chat, pos = 0; we increment each query's position after each token.
|
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
//
|
|
||||||
// `query_idx_start` is the query_idx of the first query in the batch, so that
|
|
||||||
// `StreamFunc` gets the global query index, not relative to the batch.
|
|
||||||
static void GenerateT(
|
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
|
||||||
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
|
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
|
||||||
const ModelWeightsPtrs& weights, Activations& activations,
|
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
|
|
||||||
const size_t num_queries = queries_prompt.size();
|
|
||||||
HWY_ASSERT(num_queries <= 4096); // non_eos uses `BitSet4096`.
|
|
||||||
HWY_ASSERT(num_queries == queries_pos_in.size());
|
|
||||||
HWY_ASSERT(num_queries == queries_prefix_end.size());
|
|
||||||
HWY_ASSERT(num_queries <= activations.x.Rows());
|
|
||||||
HWY_ASSERT(num_queries == kv_caches.size());
|
|
||||||
|
|
||||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
if (queries_pos_in[i] == 0) {
|
if (qbatch.MutablePos(qi) == 0) {
|
||||||
kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models.
|
qbatch.KV(qi).ZeroGriffinCache(); // No-op for non-Griffin models.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy so we can increment without requiring users to pass in a mutable span.
|
|
||||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
|
||||||
queries_pos_in.cend());
|
|
||||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
|
||||||
queries_pos_copy.size());
|
|
||||||
|
|
||||||
size_t max_prompt_size = 0;
|
size_t max_prompt_size = 0;
|
||||||
bool all_prefix_end_are_zero = true;
|
bool all_prefix_end_are_zero = true;
|
||||||
size_t prefill_tokens = 0;
|
size_t prefill_tokens = 0; // only for timing.
|
||||||
const size_t seq_len = kv_caches[0].SeqLen();
|
const size_t seq_len = qbatch.KV(0).SeqLen();
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
const PromptTokens& prompt = queries_prompt[qi];
|
const PromptTokens& prompt = qbatch.Prompt(qi);
|
||||||
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||||
|
|
||||||
// Prefill stops before size - 1 because the last prompt token is the
|
// Prefill stops before size - 1 because the last prompt token is the
|
||||||
|
|
@ -526,43 +464,38 @@ static void GenerateT(
|
||||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||||
|
|
||||||
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
|
||||||
|
|
||||||
// We use a single divisor, so all sequence lengths must be the same.
|
// We use a single divisor, so all sequence lengths must be the same.
|
||||||
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
|
HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
|
||||||
}
|
}
|
||||||
HWY_ASSERT(prefill_tokens < seq_len);
|
HWY_ASSERT(prefill_tokens < seq_len);
|
||||||
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
|
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
|
||||||
|
|
||||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||||
// qi loops anyway.
|
// qi loops anyway.
|
||||||
hwy::BitSet4096<> non_eos;
|
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||||
|
|
||||||
timing_info.prefill_start = hwy::platform::Now();
|
timing_info.prefill_start = hwy::platform::Now();
|
||||||
// Batch over the larger of prompt length, or queries.
|
// Batch over the larger of prompt length, or queries.
|
||||||
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
|
||||||
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch
|
||||||
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
|
||||||
queries_prefix_end, max_prompt_size, config, runtime_config,
|
qbatch, env, non_eos);
|
||||||
weights, activations, kv_caches, env, non_eos);
|
|
||||||
} else {
|
} else {
|
||||||
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
|
||||||
queries_prefix_end, config, runtime_config, weights,
|
non_eos);
|
||||||
activations, kv_caches, env, non_eos);
|
activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch.
|
||||||
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
|
||||||
}
|
}
|
||||||
HWY_DASSERT(num_queries == non_eos.Count());
|
HWY_DASSERT(non_eos.Count() == qbatch.Size());
|
||||||
timing_info.NotifyPrefill(prefill_tokens);
|
timing_info.NotifyPrefill(prefill_tokens);
|
||||||
// queries_pos have been incremented by Prefill.
|
// queries_pos have been incremented by Prefill.
|
||||||
|
|
||||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
const size_t last_token_pos_in_prompt =
|
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||||
queries_mutable_pos[qi] - queries_pos_in[qi];
|
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
|
||||||
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
|
runtime_config, qbatch, non_eos);
|
||||||
queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
|
|
||||||
config, runtime_config, activations, non_eos);
|
|
||||||
// No incrementing queries_mutable_pos[qi].
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||||
|
|
@ -577,10 +510,8 @@ static void GenerateT(
|
||||||
{
|
{
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
DecodeStepT(config, runtime_config, weights, sample_token, activations,
|
||||||
queries_prefix_end, config, runtime_config, weights,
|
qbatch, env, non_eos, timing_info);
|
||||||
sample_token, activations, kv_caches, env, non_eos,
|
|
||||||
timing_info);
|
|
||||||
}
|
}
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
}
|
}
|
||||||
|
|
@ -591,61 +522,38 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const ModelWeightsPtrs& weights, KVCache& kv_cache,
|
const ModelWeightsPtrs& weights, KVCache& kv_cache,
|
||||||
MatMulEnv& env, TimingInfo& timing_info) {
|
MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
constexpr size_t kNumQueries = 1;
|
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||||
const size_t qbatch_start = 0;
|
kv_cache.SeqLen(), env.row_ptrs);
|
||||||
|
|
||||||
const size_t max_batch_size =
|
AllQueries all_queries(prompt, pos, prefix_end,
|
||||||
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
|
hwy::Span<KVCache>(&kv_cache, 1));
|
||||||
// TODO: move into Gemma?
|
QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
|
||||||
Activations activations(config, max_batch_size, env.row_ptrs);
|
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||||
|
|
||||||
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
|
|
||||||
QueriesPos queries_pos(&pos, kNumQueries);
|
|
||||||
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
|
||||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
|
||||||
|
|
||||||
GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
|
|
||||||
config, runtime_config, weights, activations, kv_caches, env,
|
|
||||||
timing_info);
|
timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
|
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
|
||||||
// queries, and calls `GenerateT` on each batch.
|
// queries, and calls `GenerateT` on each batch.
|
||||||
void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
|
void GenerateBatchT(const ModelConfig& config,
|
||||||
const QueriesPos& queries_pos,
|
|
||||||
const QueriesPos& queries_prefix_end,
|
|
||||||
const ModelConfig& config,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
|
const ModelWeightsPtrs& weights, AllQueries& all_queries,
|
||||||
MatMulEnv& env, TimingInfo& timing_info) {
|
MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
|
||||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
runtime_config.prefill_tbatch_size);
|
||||||
HWY_ASSERT(kv_caches.size() >= num_queries);
|
Activations activations(config, max_batch_size,
|
||||||
|
all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
|
||||||
|
|
||||||
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
for (size_t start = 0; start < all_queries.NumQueries();
|
||||||
const size_t max_batch_size =
|
start += runtime_config.decode_qbatch_size) {
|
||||||
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
|
||||||
Activations activations(config, max_batch_size, env.row_ptrs);
|
// Generate a batch of one token for each of `qbatch.Size()` queries.
|
||||||
|
GenerateT(config, runtime_config, weights, activations, qbatch, env,
|
||||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
|
||||||
qbatch_start += max_qbatch_size) {
|
|
||||||
// Generate one batch of tokens from `qbatch_size` queries.
|
|
||||||
const size_t qbatch_size =
|
|
||||||
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
|
|
||||||
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
|
|
||||||
qbatch_size);
|
|
||||||
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
|
|
||||||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
|
||||||
qbatch_size);
|
|
||||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
|
||||||
GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
|
|
||||||
config, runtime_config, weights, activations, qbatch_kv, env,
|
|
||||||
timing_info);
|
timing_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateImageTokensT(const ModelConfig& config,
|
void GenerateImageTokensT(const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config, size_t seq_len,
|
||||||
const ModelWeightsPtrs& weights, const Image& image,
|
const ModelWeightsPtrs& weights, const Image& image,
|
||||||
ImageTokens& image_tokens, MatMulEnv& env) {
|
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||||
if (config.vit_config.layer_configs.empty()) {
|
if (config.vit_config.layer_configs.empty()) {
|
||||||
|
|
@ -656,7 +564,8 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
const size_t num_tokens = vit_config.max_seq_len;
|
const size_t num_tokens = vit_config.max_seq_len;
|
||||||
prefill_runtime_config.prefill_tbatch_size =
|
prefill_runtime_config.prefill_tbatch_size =
|
||||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||||
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
|
Activations prefill_activations(vit_config, num_tokens, num_tokens,
|
||||||
|
env.row_ptrs);
|
||||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||||
prefill_activations, env);
|
prefill_activations, env);
|
||||||
|
|
@ -714,36 +623,25 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
const QueriesPromptTokens& queries_prompt,
|
AllQueries& all_queries,
|
||||||
const QueriesPos& queries_pos,
|
|
||||||
const QueriesPos& queries_prefix_end,
|
|
||||||
const KVCaches& kv_caches,
|
|
||||||
TimingInfo& timing_info) const {
|
TimingInfo& timing_info) const {
|
||||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
|
||||||
QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
|
|
||||||
std::vector<size_t> prefix_end_vec;
|
|
||||||
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
|
|
||||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
|
||||||
queries_prefix_end_or_zeros =
|
|
||||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||||
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
|
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
|
||||||
queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
|
weights_, all_queries, env_,
|
||||||
runtime_config, weights_, kv_caches, env_, timing_info);
|
timing_info);
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||||
const Image& image,
|
size_t seq_len, const Image& image,
|
||||||
ImageTokens& image_tokens) const {
|
ImageTokens& image_tokens) const {
|
||||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||||
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
|
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
|
||||||
model_.Config(), runtime_config, weights_, image, image_tokens, env_);
|
seq_len, weights_, image,
|
||||||
|
image_tokens, env_);
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
142
gemma/gemma.h
142
gemma/gemma.h
|
|
@ -38,6 +38,129 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
struct PerQuery {
|
||||||
|
PromptTokens prompt;
|
||||||
|
|
||||||
|
// Position in the KV cache: initially zero for the first turn, or when
|
||||||
|
// multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
|
||||||
|
size_t mutable_pos;
|
||||||
|
// Allows computing the last prefill token as `mutable_pos - initial_pos`,
|
||||||
|
// which might differ from `prompt.size() - 1` for prefix-LM.
|
||||||
|
size_t initial_pos;
|
||||||
|
// Zero for causal attention, or the end of the prefix for prefix-LM style
|
||||||
|
// attention in Paligemma.
|
||||||
|
size_t prefix_end;
|
||||||
|
|
||||||
|
KVCache& kv_cache;
|
||||||
|
|
||||||
|
// Previous token generated for this query, or the last prompt token. Will be
|
||||||
|
// fed into the next Transformer() call.
|
||||||
|
int prev_token = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
|
||||||
|
struct AllQueries {
|
||||||
|
// For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
|
||||||
|
AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||||
|
const hwy::Span<KVCache>& kv_caches) {
|
||||||
|
per_query_.reserve(kv_caches.size());
|
||||||
|
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||||
|
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||||
|
per_query_.push_back(PerQuery{
|
||||||
|
.prompt = prompt,
|
||||||
|
.mutable_pos = pos,
|
||||||
|
.initial_pos = pos,
|
||||||
|
.prefix_end = prefix_end,
|
||||||
|
.kv_cache = kv_caches[i],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch of queries with initial position set to zero. Causal attention
|
||||||
|
// is requested via empty or all-zero `prefix_end`.
|
||||||
|
AllQueries(
|
||||||
|
const hwy::Span<const PromptTokens>& prompts,
|
||||||
|
const hwy::Span<KVCache>& kv_caches,
|
||||||
|
const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
|
||||||
|
HWY_ASSERT(prompts.size() == kv_caches.size());
|
||||||
|
HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
|
||||||
|
per_query_.reserve(kv_caches.size());
|
||||||
|
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||||
|
HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
|
||||||
|
per_query_.push_back(PerQuery{
|
||||||
|
.prompt = prompts[i],
|
||||||
|
.mutable_pos = 0,
|
||||||
|
.initial_pos = 0,
|
||||||
|
.prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
|
||||||
|
.kv_cache = kv_caches[i],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t NumQueries() const { return per_query_.size(); }
|
||||||
|
|
||||||
|
PerQuery& operator[](size_t query_idx) {
|
||||||
|
HWY_DASSERT(query_idx < NumQueries());
|
||||||
|
return per_query_[query_idx];
|
||||||
|
}
|
||||||
|
const PerQuery& operator[](size_t query_idx) const {
|
||||||
|
HWY_DASSERT(query_idx < NumQueries());
|
||||||
|
return per_query_[query_idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<PerQuery> per_query_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// View into AllQueries: either a batch of queries, or a single query for use
|
||||||
|
// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
|
||||||
|
// reference to AllQueries.
|
||||||
|
class QBatch {
|
||||||
|
public:
|
||||||
|
QBatch(size_t start, size_t max_size, AllQueries& queries)
|
||||||
|
: start_(start),
|
||||||
|
max_size_(max_size),
|
||||||
|
queries_(queries),
|
||||||
|
size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
|
||||||
|
HWY_ASSERT(max_size_ <= 4096); // non_eos uses `BitSet4096`.
|
||||||
|
HWY_DASSERT(size_ != 0);
|
||||||
|
HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a single-query view starting at `qi` relative to this batch.
|
||||||
|
QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
|
||||||
|
|
||||||
|
// How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
|
||||||
|
size_t Size() const { return size_; }
|
||||||
|
|
||||||
|
// Returns index for use with `AllQueries` and `BatchStreamToken`.
|
||||||
|
size_t QueryIdx(size_t qi) const {
|
||||||
|
HWY_DASSERT(qi < size_);
|
||||||
|
return start_ + qi;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accessor functions to bridge the previous SoA and current AoS layout.
|
||||||
|
const PromptTokens& Prompt(size_t qi) const {
|
||||||
|
return queries_[QueryIdx(qi)].prompt;
|
||||||
|
}
|
||||||
|
size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||||
|
size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
|
||||||
|
size_t InitialPos(size_t qi) const {
|
||||||
|
return queries_[QueryIdx(qi)].initial_pos;
|
||||||
|
}
|
||||||
|
size_t PrefixEnd(size_t qi) const {
|
||||||
|
return queries_[QueryIdx(qi)].prefix_end;
|
||||||
|
}
|
||||||
|
KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
|
||||||
|
int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t start_;
|
||||||
|
size_t max_size_;
|
||||||
|
AllQueries& queries_;
|
||||||
|
size_t size_;
|
||||||
|
};
|
||||||
|
|
||||||
struct TimingInfo {
|
struct TimingInfo {
|
||||||
// be sure to populate prefill_start before calling NotifyPrefill.
|
// be sure to populate prefill_start before calling NotifyPrefill.
|
||||||
void NotifyPrefill(size_t tokens) {
|
void NotifyPrefill(size_t tokens) {
|
||||||
|
|
@ -100,8 +223,6 @@ struct TimingInfo {
|
||||||
// Returns the `MatMulEnv` after calling `SetArgs`.
|
// Returns the `MatMulEnv` after calling `SetArgs`.
|
||||||
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
|
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
|
||||||
|
|
||||||
using KVCaches = hwy::Span<KVCache>;
|
|
||||||
|
|
||||||
class Gemma {
|
class Gemma {
|
||||||
public:
|
public:
|
||||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||||
|
|
@ -133,24 +254,11 @@ class Gemma {
|
||||||
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
size_t pos, size_t prefix_end, KVCache& kv_cache,
|
||||||
TimingInfo& timing_info) const;
|
TimingInfo& timing_info) const;
|
||||||
|
|
||||||
// `queries_pos` are the positions in the KV cache. Users are responsible for
|
|
||||||
// incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
|
|
||||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
void GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
const QueriesPromptTokens& queries_prompt,
|
AllQueries& all_queries, TimingInfo& timing_info) const;
|
||||||
const QueriesPos& queries_pos, const KVCaches& kv_caches,
|
|
||||||
TimingInfo& timing_info) const {
|
|
||||||
GenerateBatch(runtime_config, queries_prompt, queries_pos,
|
|
||||||
/*queries_prefix_end=*/{}, kv_caches, timing_info);
|
|
||||||
}
|
|
||||||
// For prefix-LM style attention, we can pass the ends of the prefixes.
|
|
||||||
void GenerateBatch(const RuntimeConfig& runtime_config,
|
|
||||||
const QueriesPromptTokens& queries_prompt,
|
|
||||||
const QueriesPos& queries_pos,
|
|
||||||
const QueriesPos& queries_prefix_end,
|
|
||||||
const KVCaches& kv_caches, TimingInfo& timing_info) const;
|
|
||||||
|
|
||||||
// Generates the image tokens by running the image encoder ViT.
|
// Generates the image tokens by running the image encoder ViT.
|
||||||
void GenerateImageTokens(const RuntimeConfig& runtime_config,
|
void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
|
||||||
const Image& image, ImageTokens& image_tokens) const;
|
const Image& image, ImageTokens& image_tokens) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -82,8 +82,9 @@ using ImageTokens = MatStorageT<float>;
|
||||||
// true to continue generation.
|
// true to continue generation.
|
||||||
using StreamFunc = std::function<bool(int, float)>;
|
using StreamFunc = std::function<bool(int, float)>;
|
||||||
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
// BatchStreamFunc is called with (query_idx, pos, token, probability).
|
||||||
// For prompt tokens, probability is 0.0f.
|
// For prompt tokens, probability is 0.0f. Generation continues if this returns
|
||||||
// StreamFunc should return false to stop generation and true to continue.
|
// true and stops if it returns false. Note that query_idx is absolute, not
|
||||||
|
// relative to the batch.
|
||||||
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
|
||||||
// If not empty, AcceptFunc is called with token. It should return false for
|
// If not empty, AcceptFunc is called with token. It should return false for
|
||||||
// tokens you don't want to generate and true for tokens you want to generate.
|
// tokens you don't want to generate and true for tokens you want to generate.
|
||||||
|
|
@ -112,8 +113,8 @@ using ActivationsObserverFunc =
|
||||||
// RuntimeConfig holds configuration for a single generation run.
|
// RuntimeConfig holds configuration for a single generation run.
|
||||||
// TODO: move into InferenceArgs, use that directly.
|
// TODO: move into InferenceArgs, use that directly.
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
// If not empty, batch_stream_token is called for each token in the batch,
|
// If non-null, `batch_stream_token` is called for each token in the batch,
|
||||||
// instead of stream_token.
|
// otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
|
||||||
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
|
||||||
if (batch_stream_token) {
|
if (batch_stream_token) {
|
||||||
return batch_stream_token(query_idx, pos, token, prob);
|
return batch_stream_token(query_idx, pos, token, prob);
|
||||||
|
|
@ -189,9 +190,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
"developer/debug info).\n Default = 1.",
|
"developer/debug info).\n Default = 1.",
|
||||||
1); // Changed verbosity level to 1 since it's user-facing
|
1); // Changed verbosity level to 1 since it's user-facing
|
||||||
|
|
||||||
visitor(seq_len, "seq_len", size_t{2048},
|
visitor(seq_len, "seq_len", size_t{8192},
|
||||||
"Sequence length, capped by ModelConfig.max_seq_len.");
|
"Sequence length, capped by ModelConfig.max_seq_len.");
|
||||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
|
||||||
"Maximum number of tokens to generate.");
|
"Maximum number of tokens to generate.");
|
||||||
|
|
||||||
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
|
visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
|
||||||
|
|
|
||||||
|
|
@ -39,16 +39,10 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
// Different functions use different naming conventions for the number of
|
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
|
||||||
// tokens. Functions that are query-independent, such as RMSNorm*, call the
|
|
||||||
// count `num_interleaved`. Functions that are query-dependent, such as
|
|
||||||
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
|
|
||||||
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
|
|
||||||
|
|
||||||
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
|
||||||
size_t griffin_layer, Activations& activations,
|
|
||||||
const LayerWeightsPtrs* layer_weights,
|
const LayerWeightsPtrs* layer_weights,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
Activations& activations, QBatch& qbatch,
|
||||||
|
MatMulEnv& env) {
|
||||||
PROFILER_ZONE("Gen.Griffin");
|
PROFILER_ZONE("Gen.Griffin");
|
||||||
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
@ -64,9 +58,8 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
||||||
const size_t kHeadDim = model_dim / heads;
|
const size_t kHeadDim = model_dim / heads;
|
||||||
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
const size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||||
|
|
||||||
const size_t num_queries = queries_pos.size();
|
const size_t num_interleaved = num_tokens * qbatch.Size();
|
||||||
const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
|
const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
|
||||||
const size_t num_interleaved = num_tokens * num_queries;
|
|
||||||
|
|
||||||
// X / Y linear layers.
|
// X / Y linear layers.
|
||||||
// TODO: MatMul
|
// TODO: MatMul
|
||||||
|
|
@ -91,17 +84,17 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
||||||
// Conv1D.
|
// Conv1D.
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
++interleaved_idx) {
|
++interleaved_idx) {
|
||||||
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
|
||||||
|
|
||||||
// cache[i] = input at time t-i.
|
// cache[i] = input at time t-i.
|
||||||
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
float* HWY_RESTRICT cache[kMaxConv1DWidth];
|
||||||
cache[0] = x;
|
cache[0] = x;
|
||||||
for (size_t i = 1; i < conv_1d_width; i++) {
|
for (size_t i = 1; i < conv_1d_width; i++) {
|
||||||
cache[i] =
|
cache[i] =
|
||||||
kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
|
qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
|
||||||
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||||
|
|
@ -127,16 +120,16 @@ void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
|
||||||
// RGLRU
|
// RGLRU
|
||||||
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
|
||||||
++interleaved_idx) {
|
++interleaved_idx) {
|
||||||
const size_t query_idx = div_num_q.Remainder(interleaved_idx);
|
const size_t qi = div_qbatch.Remainder(interleaved_idx);
|
||||||
const size_t batch_idx = div_num_q.Divide(interleaved_idx);
|
const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
|
||||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
const size_t pos = qbatch.Pos(qi) + batch_idx;
|
||||||
|
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
|
float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
|
||||||
float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
|
float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
|
||||||
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
|
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
|
||||||
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
|
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
|
||||||
float* HWY_RESTRICT rnn_state =
|
float* HWY_RESTRICT rnn_state =
|
||||||
kv_caches[query_idx].rglru_cache.Row(griffin_layer);
|
qbatch.KV(qi).rglru_cache.Row(griffin_layer);
|
||||||
|
|
||||||
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
size_t head_offset = head * kHeadDim;
|
size_t head_offset = head * kHeadDim;
|
||||||
|
|
|
||||||
|
|
@ -26,13 +26,13 @@
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||||
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
|
#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE) \
|
||||||
namespace NAMESPACE { \
|
namespace NAMESPACE { \
|
||||||
void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens, \
|
void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, \
|
||||||
size_t griffin_layer, Activations& activations, \
|
const LayerWeightsPtrs* layer_weights, \
|
||||||
const LayerWeightsPtrs* layer_weights, \
|
Activations& activations, QBatch& qbatch, \
|
||||||
const KVCaches& kv_caches, MatMulEnv& env); \
|
MatMulEnv& env); \
|
||||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||||
} // namespace NAMESPACE
|
} // namespace NAMESPACE
|
||||||
|
|
||||||
// Function declarations for each SIMD target. Allows direct call from the
|
// Function declarations for each SIMD target. Allows direct call from the
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
.verbosity = inference.verbosity,
|
.verbosity = inference.verbosity,
|
||||||
.use_spinning = threading.spin};
|
.use_spinning = threading.spin};
|
||||||
double image_tokens_start = hwy::platform::Now();
|
double image_tokens_start = hwy::platform::Now();
|
||||||
gemma.GenerateImageTokens(runtime_config, image, image_tokens);
|
gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
|
||||||
|
image_tokens);
|
||||||
if (inference.verbosity >= 1) {
|
if (inference.verbosity >= 1) {
|
||||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,8 @@ class PaliGemmaTest : public ::testing::Test {
|
||||||
image.Resize(image_size, image_size);
|
image.Resize(image_size, image_size);
|
||||||
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
|
||||||
.verbosity = 0};
|
.verbosity = 0};
|
||||||
gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
|
gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
|
||||||
|
image, *image_tokens_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GemmaReply(const std::string& prompt_text) const {
|
std::string GemmaReply(const std::string& prompt_text) const {
|
||||||
|
|
@ -107,12 +108,11 @@ class PaliGemmaTest : public ::testing::Test {
|
||||||
TEST_F(PaliGemmaTest, QueryObjects) {
|
TEST_F(PaliGemmaTest, QueryObjects) {
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
const char* question = "answer en What objects are in the image?";
|
const char* question = "answer en What objects are in the image?";
|
||||||
const char* expected_substring = "Building, Tower"; // 3B PT 224, 10B Mix 224
|
// 3B PT/Mix 224, 10B Mix 224
|
||||||
|
const char* expected_substring = "Building, Tower";
|
||||||
const Model model = s_env->GetGemma()->GetModelConfig().model;
|
const Model model = s_env->GetGemma()->GetModelConfig().model;
|
||||||
if (model == Model::PALIGEMMA2_3B_448) {
|
if (model == Model::PALIGEMMA2_3B_448) {
|
||||||
expected_substring = "Lake.";
|
expected_substring = "Lake.";
|
||||||
} else if (model == Model::PALIGEMMA2_3B_224) {
|
|
||||||
expected_substring = "Cloud, Water.";
|
|
||||||
} else if (model == Model::PALIGEMMA2_10B_224) {
|
} else if (model == Model::PALIGEMMA2_10B_224) {
|
||||||
expected_substring = "Building.";
|
expected_substring = "Building.";
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,8 @@ class GemmaModel {
|
||||||
gcpp::MatPadding::kOdd));
|
gcpp::MatPadding::kOdd));
|
||||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||||
.verbosity = 0};
|
.verbosity = 0};
|
||||||
gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
|
gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
|
||||||
|
c_image, *image_tokens_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates a response to the given prompt, using the last set image.
|
// Generates a response to the given prompt, using the last set image.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue