mirror of https://github.com/google/gemma.cpp.git
6x large-batch, short-prompt prefill speedup
Parallelize over queries instead of tokens. Introduce non_eos so we only iterate over queries that have not yet reached EOS; remove TokenStreamer. Move RMSNormInplaceBatched out of Transformer so that the latter can be called from prefill. Use a consistent argument order. Fix gemma_test EOS handling (caught by msan) and remove it from tokenizer.h. Also add output to gemma_batch_bench and fix its name. PiperOrigin-RevId: 769676106
This commit is contained in:
parent
d7b23d532a
commit
ec02726cf7
|
|
@ -33,7 +33,7 @@ namespace {
|
||||||
// non-local static variables with dtors.
|
// non-local static variables with dtors.
|
||||||
GemmaEnv* s_env = nullptr;
|
GemmaEnv* s_env = nullptr;
|
||||||
|
|
||||||
class GemmaTest : public ::testing::Test {
|
class GemmaBatchBench : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
std::vector<std::string> BatchGemmaReply(
|
std::vector<std::string> BatchGemmaReply(
|
||||||
const std::vector<std::string>& inputs) {
|
const std::vector<std::string>& inputs) {
|
||||||
|
|
@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GemmaTest, RandomQuestionsBatched) {
|
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
|
||||||
const std::vector<std::string> questions = {
|
const std::vector<std::string> questions = {
|
||||||
{"Write me a poem about Australia?"},
|
{"Write me a poem about Australia?"},
|
||||||
{"What's the history of Denmark?"},
|
{"What's the history of Denmark?"},
|
||||||
|
|
@ -103,6 +103,7 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
fprintf(stderr, "GemmaEnv setup..\n");
|
||||||
gcpp::GemmaEnv env(argc, argv);
|
gcpp::GemmaEnv env(argc, argv);
|
||||||
gcpp::s_env = &env;
|
gcpp::s_env = &env;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ TEST_F(GemmaTest, Multiturn) {
|
||||||
size_t abs_pos = 0;
|
size_t abs_pos = 0;
|
||||||
std::string response;
|
std::string response;
|
||||||
auto stream_token = [&](int token, float) {
|
auto stream_token = [&](int token, float) {
|
||||||
if (token == EOS_ID) return true;
|
if (config.IsEOS(token)) return true;
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@
|
||||||
#include "util/allocator.h" // Allocator
|
#include "util/allocator.h" // Allocator
|
||||||
#include "util/basics.h" // BF16
|
#include "util/basics.h" // BF16
|
||||||
#include "util/mat.h" // MatStorageT
|
#include "util/mat.h" // MatStorageT
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -48,6 +49,7 @@ struct Activations {
|
||||||
seq_len(config.seq_len),
|
seq_len(config.seq_len),
|
||||||
cache_pos_size(config.CachePosSize()),
|
cache_pos_size(config.CachePosSize()),
|
||||||
is_griffin(config.model == Model::GRIFFIN_2B),
|
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||||
|
query_scale(ChooseQueryScale(config)),
|
||||||
|
|
||||||
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
x("x", Extents2D(batch_size, config.model_dim), pad_),
|
||||||
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
|
||||||
|
|
@ -96,7 +98,7 @@ struct Activations {
|
||||||
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
|
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
|
||||||
1000000.0)),
|
1000000.0)),
|
||||||
|
|
||||||
query_scale(ChooseQueryScale(config)) {
|
gen_tokens(batch_size) {
|
||||||
HWY_ASSERT(batch_size != 0);
|
HWY_ASSERT(batch_size != 0);
|
||||||
|
|
||||||
// For MatMul outputs, precompute their row pointers.
|
// For MatMul outputs, precompute their row pointers.
|
||||||
|
|
@ -114,6 +116,7 @@ struct Activations {
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetBatchSize(size_t batch_size) {
|
void SetBatchSize(size_t batch_size) {
|
||||||
|
PROFILER_ZONE("SetBatchSize");
|
||||||
x.OverrideRows(batch_size);
|
x.OverrideRows(batch_size);
|
||||||
q.OverrideRows(batch_size);
|
q.OverrideRows(batch_size);
|
||||||
logits.OverrideRows(batch_size);
|
logits.OverrideRows(batch_size);
|
||||||
|
|
@ -134,13 +137,16 @@ struct Activations {
|
||||||
griffin_gate_x.OverrideRows(batch_size);
|
griffin_gate_x.OverrideRows(batch_size);
|
||||||
griffin_multiplier.OverrideRows(batch_size);
|
griffin_multiplier.OverrideRows(batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gen_tokens.resize(batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ModelConfig& weights_config;
|
const ModelConfig& weights_config;
|
||||||
const LayerConfig& layer_config;
|
const LayerConfig& layer_config;
|
||||||
size_t seq_len;
|
size_t seq_len;
|
||||||
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
|
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
|
||||||
bool is_griffin = false;
|
bool is_griffin;
|
||||||
|
float query_scale;
|
||||||
const Extents2D none_ = Extents2D();
|
const Extents2D none_ = Extents2D();
|
||||||
const MatPadding pad_ = MatPadding::kOdd;
|
const MatPadding pad_ = MatPadding::kOdd;
|
||||||
|
|
||||||
|
|
@ -171,7 +177,9 @@ struct Activations {
|
||||||
MatStorageT<float> inv_timescale;
|
MatStorageT<float> inv_timescale;
|
||||||
MatStorageT<float> inv_timescale_global;
|
MatStorageT<float> inv_timescale_global;
|
||||||
|
|
||||||
float query_scale;
|
// Storage for the last generated token from each query, passed to the next
|
||||||
|
// Transformer() call.
|
||||||
|
std::vector<int> gen_tokens; // one per query in the batch
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
412
gemma/gemma.cc
412
gemma/gemma.cc
|
|
@ -181,21 +181,24 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
|
||||||
return image_token_position;
|
return image_token_position;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefill() and Transformer() increment positions in-place.
|
// Incremented in-place by Prefill* and DecodeStepT.
|
||||||
using QueriesMutablePos = hwy::Span<size_t>;
|
using QueriesMutablePos = hwy::Span<size_t>;
|
||||||
|
|
||||||
// Populates KV cache for batches of tokens from one query at a time.
|
// Populates KV cache for batches of tokens from one query at a time. This is
|
||||||
static HWY_NOINLINE void Prefill(
|
// called if prompts are longer than the query batch size, and also in
|
||||||
|
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
||||||
|
static HWY_NOINLINE void PrefillTBatch(
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||||
const hwy::Divisor& div_seq_len, const ModelConfig& config,
|
const hwy::Divisor& div_seq_len, const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||||
PROFILER_ZONE("Gen.Prefill");
|
hwy::BitSet4096<>& non_eos) {
|
||||||
|
PROFILER_ZONE("Gen.PrefillT");
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
HWY_DASSERT(queries_prefix_end.size() == num_queries);
|
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||||
HWY_DASSERT(kv_caches.size() == num_queries);
|
HWY_DASSERT(num_queries == kv_caches.size());
|
||||||
|
|
||||||
// Batches are important for amortizing loading weights over multiple tokens.
|
// Batches are important for amortizing loading weights over multiple tokens.
|
||||||
// This is possible in prefill because we know all tokens beforehand, whereas
|
// This is possible in prefill because we know all tokens beforehand, whereas
|
||||||
|
|
@ -203,13 +206,15 @@ static HWY_NOINLINE void Prefill(
|
||||||
// a query requires that preceding batches already wrote to the KV cache,
|
// a query requires that preceding batches already wrote to the KV cache,
|
||||||
// hence we sequentially loop over token batches. We can reduce the number of
|
// hence we sequentially loop over token batches. We can reduce the number of
|
||||||
// iterations by increasing the batch size, but this also increases arithmetic
|
// iterations by increasing the batch size, but this also increases arithmetic
|
||||||
// intensity, and so we are eventually compute-limited. We could devote some
|
// intensity, and so we are eventually compute-limited. TransformerLayer uses
|
||||||
// threads to parallelizing over queries, but for simplicity we assign them
|
// all available threads, so we do not also parallelize over queries, but note
|
||||||
// all to MatMul.
|
// that PrefillQBatch uses queries as the batch dimension.
|
||||||
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
|
||||||
|
|
||||||
// For each query. `qi` is within the batch, not the global query index.
|
// For each query. `qi` is within the batch, not the global query index.
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
non_eos.Set(qi);
|
||||||
|
|
||||||
// Single query at a time, so pass slices of the spans because
|
// Single query at a time, so pass slices of the spans because
|
||||||
// GemmaAttention will only access the first KV cache and position.
|
// GemmaAttention will only access the first KV cache and position.
|
||||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||||
|
|
@ -292,30 +297,34 @@ static HWY_NOINLINE void Prefill(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generates one token for each query. `queries_token` is the previous token
|
// Embeds token and calls each TransformerLayer. `queries_token` is the previous
|
||||||
// from each query, and `queries_pos` are their position in the sequence.
|
// token from each query, and `queries_pos` are their position in the sequence.
|
||||||
|
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
|
||||||
|
// token-batched `PrefillTBatch`.
|
||||||
static HWY_NOINLINE void Transformer(
|
static HWY_NOINLINE void Transformer(
|
||||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
||||||
const ModelConfig& config, const ModelWeightsPtrs& weights,
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
const ModelWeightsPtrs& weights, Activations& activations,
|
||||||
const LayersOutputFunc& layers_output,
|
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||||
const ActivationsObserverFunc& activations_observer) {
|
|
||||||
const size_t num_queries = queries_token.size();
|
const size_t num_queries = queries_token.size();
|
||||||
HWY_DASSERT(queries_pos.size() == num_queries);
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
HWY_DASSERT(queries_prefix_end.size() == num_queries);
|
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||||
|
|
||||||
if (layers_output) {
|
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
const float token_f = queries_token[query_idx];
|
const float token_f = queries_token[qi];
|
||||||
layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
|
runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
|
||||||
1);
|
1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
size_t image_token_position = 0;
|
||||||
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
/*pos_in_prompt=*/0, config, weights, activations.x);
|
image_token_position =
|
||||||
|
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
|
||||||
|
/*pos_in_prompt=*/0, config, weights, activations.x,
|
||||||
|
runtime_config.image_tokens, image_token_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||||
|
|
@ -323,21 +332,71 @@ static HWY_NOINLINE void Transformer(
|
||||||
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
||||||
activations, kv_caches, env);
|
activations, kv_caches, env);
|
||||||
|
|
||||||
if (activations_observer) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
activations_observer(queries_pos, layer_idx, activations);
|
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
// Populates KV cache for the batch queries, one token at a time. Only called
|
||||||
|
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
|
||||||
|
static HWY_NOINLINE void PrefillQBatch(
|
||||||
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
|
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||||
|
const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
|
||||||
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
|
const ModelWeightsPtrs& weights, Activations& activations,
|
||||||
|
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
||||||
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
|
const size_t num_queries = queries_prompt.size();
|
||||||
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
|
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||||
|
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||||
|
HWY_DASSERT(num_queries == kv_caches.size());
|
||||||
|
|
||||||
if (activations_observer) {
|
hwy::BitSet4096<> prefill_active;
|
||||||
activations_observer(queries_pos, -1, activations);
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
prefill_active.Set(qi);
|
||||||
|
|
||||||
|
HWY_DASSERT(queries_prefix_end[qi] == 0);
|
||||||
|
(void)queries_prefix_end;
|
||||||
}
|
}
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
non_eos = prefill_active;
|
||||||
queries_pos[query_idx] += 1;
|
|
||||||
|
// In autoregressive mode, we don't prefill the last token, hence - 1.
|
||||||
|
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
|
||||||
|
++pos_in_prompt) {
|
||||||
|
// Streams that have already finished prefill no longer interleave/stream.
|
||||||
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
|
||||||
|
prefill_active.Clear(qi);
|
||||||
|
activations.gen_tokens[qi] = config.eos_id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Batch := interleaved tokens, one from each non-EOS query.
|
||||||
|
prefill_active.Foreach([&](size_t qi) {
|
||||||
|
activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
|
||||||
|
});
|
||||||
|
|
||||||
|
// One token from each query in the batch. Increments queries_pos.
|
||||||
|
// Do not call DecodeStepT because it computes logits for token
|
||||||
|
// probabilities, which are not required for the prompt tokens.
|
||||||
|
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||||
|
queries_pos, queries_prefix_end, div_seq_len, config,
|
||||||
|
runtime_config, weights, activations, kv_caches, env);
|
||||||
|
|
||||||
|
prefill_active.Foreach([&](size_t qi) {
|
||||||
|
const int token = queries_prompt[qi][pos_in_prompt];
|
||||||
|
// Ignore any user request to stop during prefill.
|
||||||
|
(void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
|
||||||
|
token, 0.0f);
|
||||||
|
queries_pos[qi] += 1;
|
||||||
|
});
|
||||||
|
} // pos_in_prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: inline.
|
||||||
void RangeChecks(const ModelConfig& weights_config,
|
void RangeChecks(const ModelConfig& weights_config,
|
||||||
size_t& max_generated_tokens, const size_t prompt_size) {
|
size_t& max_generated_tokens, const size_t prompt_size) {
|
||||||
if (!weights_config.use_local_attention) {
|
if (!weights_config.use_local_attention) {
|
||||||
|
|
@ -350,56 +409,49 @@ void RangeChecks(const ModelConfig& weights_config,
|
||||||
HWY_ASSERT(prompt_size > 0);
|
HWY_ASSERT(prompt_size > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Holds "is at end of stream" state for each query.
|
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
||||||
class TokenStreamer {
|
// and updates `non_eos` if the query is at the end of its sequence.
|
||||||
public:
|
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||||
TokenStreamer(const RuntimeConfig& runtime_config,
|
const float prob, const ModelConfig& config,
|
||||||
const ModelConfig& model_config)
|
const RuntimeConfig& runtime_config,
|
||||||
: runtime_config_(runtime_config), model_config_(model_config) {}
|
Activations& activations,
|
||||||
|
hwy::BitSet4096<>& non_eos) {
|
||||||
|
HWY_DASSERT(non_eos.Get(qi));
|
||||||
|
|
||||||
// Returns whether the query was already at, or has just reached, the end of
|
// User decided to stop: set next token to primary EOS.
|
||||||
// the stream: either via token == eos_id, or StreamToken returning false.
|
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
|
||||||
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
|
token = config.eos_id;
|
||||||
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
|
|
||||||
|
|
||||||
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
|
|
||||||
model_config_.IsEOS(token)) {
|
|
||||||
is_eos_.Set(query_idx);
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
// Primary or secondary EOS: mark query as EOS.
|
||||||
|
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
|
||||||
|
|
||||||
|
activations.gen_tokens[qi] = token;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
// For a batch of queries, runs Transformer, computes logits, samples and
|
||||||
const RuntimeConfig& runtime_config_;
|
// streams the token.
|
||||||
const ModelConfig& model_config_;
|
static void DecodeStepT(
|
||||||
hwy::BitSet4096<> is_eos_;
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
};
|
|
||||||
|
|
||||||
// Runs one decode step for all the queries in the batch. Returns true if all
|
|
||||||
// queries are at <end_of_sentence>.
|
|
||||||
static bool DecodeStepT(
|
|
||||||
const ModelConfig& config, const ModelWeightsPtrs& weights,
|
|
||||||
const RuntimeConfig& runtime_config, const size_t query_idx_start,
|
|
||||||
const QueriesPromptTokens& queries_prompt,
|
|
||||||
const QueriesMutablePos& queries_mutable_pos,
|
const QueriesMutablePos& queries_mutable_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
|
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
|
||||||
const size_t vocab_size, const SampleFunc& sample_token,
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
Activations& activations, const KVCaches& kv_caches,
|
const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
|
||||||
TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
|
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||||
TimingInfo& timing_info, MatMulEnv& env) {
|
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
// Decode generates one token per query and increments
|
|
||||||
// queries_mutable_pos.
|
|
||||||
Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
|
|
||||||
queries_prefix_end, div_seq_len, config, weights, activations,
|
|
||||||
kv_caches, env, runtime_config.layers_output,
|
|
||||||
runtime_config.activations_observer);
|
|
||||||
// queries_pos are incremented by Transformer.
|
|
||||||
|
|
||||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||||
bool all_queries_eos = true;
|
|
||||||
|
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||||
|
queries_mutable_pos, queries_prefix_end, div_seq_len, config,
|
||||||
|
runtime_config, weights, activations, kv_caches, env);
|
||||||
|
|
||||||
|
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||||
|
|
||||||
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
|
runtime_config.activations_observer(queries_mutable_pos, -1, activations);
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
PROFILER_ZONE("Gen.EmbeddingMatmul");
|
||||||
// Compute logits from last layer activations.
|
// Compute logits from last layer activations.
|
||||||
|
|
@ -407,19 +459,17 @@ static bool DecodeStepT(
|
||||||
/*add=*/nullptr, env, activations.logits);
|
/*add=*/nullptr, env, activations.logits);
|
||||||
}
|
}
|
||||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
non_eos.Foreach([&](size_t qi) {
|
||||||
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
|
float* HWY_RESTRICT logits = activations.logits.Row(qi);
|
||||||
MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
|
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size);
|
||||||
const TokenAndProb tp = sample_token(logits, vocab_size);
|
const TokenAndProb tp = sample_token(logits, config.vocab_size);
|
||||||
timing_info.NotifyGenerated();
|
timing_info.NotifyGenerated();
|
||||||
|
|
||||||
const bool is_eos =
|
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
|
||||||
token_streamer(query_idx_start + query_idx,
|
tp.prob, config, runtime_config, activations, non_eos);
|
||||||
queries_mutable_pos[query_idx], tp.token, tp.prob);
|
|
||||||
all_queries_eos &= is_eos;
|
if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
|
||||||
gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token;
|
});
|
||||||
}
|
|
||||||
return all_queries_eos;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE SampleFunc
|
static HWY_INLINE SampleFunc
|
||||||
|
|
@ -445,33 +495,26 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the min and max number of tokens for all queries.
|
|
||||||
static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
|
|
||||||
size_t max_prompt_size = 0;
|
|
||||||
for (size_t i = 0; i < queries_prompt.size(); ++i) {
|
|
||||||
max_prompt_size = HWY_MAX(max_prompt_size, queries_prompt[i].size());
|
|
||||||
}
|
|
||||||
return max_prompt_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generates one continuation for each query in `queries_prompt`, which is one
|
// Generates one continuation for each query in `queries_prompt`, which is one
|
||||||
// qbatch whose size is at most the `batch_size` passed to
|
// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
|
||||||
// `activations.Allocate`.
|
|
||||||
//
|
//
|
||||||
// `queries_pos` stores the KV cache position for each query. In the first turn
|
// `queries_pos` stores the KV cache position for each query. In the first turn
|
||||||
// of a chat, pos = 0; we increment each query's position after each token.
|
// of a chat, pos = 0; we increment each query's position after each token.
|
||||||
//
|
//
|
||||||
// `query_idx_start` is the query_idx of the first query in the batch, so that
|
// `query_idx_start` is the query_idx of the first query in the batch, so that
|
||||||
// `StreamFunc` gets the global query index, not relative to the batch.
|
// `StreamFunc` gets the global query index, not relative to the batch.
|
||||||
//
|
|
||||||
// `kv_caches` is for the batch, size must match `queries_prompt`.
|
|
||||||
static void GenerateT(
|
static void GenerateT(
|
||||||
const ModelConfig& config, const ModelWeightsPtrs& weights,
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
const RuntimeConfig& runtime_config, const size_t query_idx_start,
|
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
|
||||||
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in,
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
const QueriesPos& queries_prefix_end, Activations& activations,
|
const ModelWeightsPtrs& weights, Activations& activations,
|
||||||
const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) {
|
const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
|
const size_t num_queries = queries_prompt.size();
|
||||||
|
HWY_ASSERT(num_queries <= 4096); // non_eos uses `BitSet4096`.
|
||||||
|
HWY_ASSERT(num_queries == queries_pos_in.size());
|
||||||
|
HWY_ASSERT(num_queries == queries_prefix_end.size());
|
||||||
|
HWY_ASSERT(num_queries <= activations.x.Rows());
|
||||||
|
HWY_ASSERT(num_queries == kv_caches.size());
|
||||||
|
|
||||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||||
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
for (size_t i = 0; i < kv_caches.size(); ++i) {
|
||||||
|
|
@ -486,75 +529,79 @@ static void GenerateT(
|
||||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||||
queries_pos_copy.size());
|
queries_pos_copy.size());
|
||||||
|
|
||||||
|
size_t max_prompt_size = 0;
|
||||||
|
bool all_prefix_end_are_zero = true;
|
||||||
|
size_t prefill_tokens = 0;
|
||||||
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
const PromptTokens& prompt = queries_prompt[qi];
|
||||||
|
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||||
|
|
||||||
|
// Prefill stops before size - 1 because the last prompt token is the
|
||||||
|
// first input token for generation.
|
||||||
|
prefill_tokens += prompt.size() - 1;
|
||||||
|
|
||||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
|
||||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
|
||||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||||
|
|
||||||
|
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t num_queries = queries_prompt.size();
|
|
||||||
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
|
|
||||||
HWY_ASSERT(num_queries <= activations.x.Rows());
|
|
||||||
HWY_ASSERT(queries_pos_in.size() == num_queries);
|
|
||||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
|
||||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||||
size_t max_prompt_size = MaxQueryLength(queries_prompt);
|
|
||||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||||
RangeChecks(config, max_generated_tokens, max_prompt_size);
|
// qi loops anyway.
|
||||||
|
hwy::BitSet4096<> non_eos;
|
||||||
|
|
||||||
|
timing_info.prefill_start = hwy::platform::Now();
|
||||||
|
// Batch over the larger of prompt length, or queries.
|
||||||
|
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
||||||
|
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
||||||
|
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
|
queries_prefix_end, max_prompt_size, div_seq_len, config,
|
||||||
|
runtime_config, weights, activations, kv_caches, env,
|
||||||
|
non_eos);
|
||||||
|
} else {
|
||||||
|
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
|
queries_prefix_end, div_seq_len, config, runtime_config,
|
||||||
|
weights, activations, kv_caches, env, non_eos);
|
||||||
|
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
||||||
|
}
|
||||||
|
HWY_DASSERT(num_queries == non_eos.Count());
|
||||||
|
timing_info.NotifyPrefill(prefill_tokens);
|
||||||
|
// queries_pos have been incremented by Prefill.
|
||||||
|
|
||||||
|
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||||
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
const size_t last_token_pos_in_prompt =
|
||||||
|
queries_mutable_pos[qi] - queries_pos_in[qi];
|
||||||
|
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
|
||||||
|
queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
|
||||||
|
config, runtime_config, activations, non_eos);
|
||||||
|
// No incrementing queries_mutable_pos[qi].
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||||
|
RangeChecks(config, max_gen_steps, max_prompt_size);
|
||||||
|
|
||||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
||||||
|
|
||||||
// Prefill stops before min_prompt_size - 1 because the last prompt
|
|
||||||
// token is the first input token for generation.
|
|
||||||
timing_info.prefill_start = hwy::platform::Now();
|
|
||||||
// Note that Prefill calls activations.SetBatchSize, so we reset it below.
|
|
||||||
Prefill(query_idx_start, queries_prompt, queries_mutable_pos,
|
|
||||||
queries_prefix_end, div_seq_len, config, runtime_config, weights,
|
|
||||||
activations, kv_caches, env);
|
|
||||||
// Compute the number of tokens that were prefilled and notify timing_info.
|
|
||||||
size_t prefilled_tokens = 0;
|
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
|
||||||
prefilled_tokens += queries_prompt[qi].size() - 1;
|
|
||||||
}
|
|
||||||
timing_info.NotifyPrefill(prefilled_tokens);
|
|
||||||
// queries_pos are incremented by Prefill.
|
|
||||||
activations.SetBatchSize(num_queries);
|
|
||||||
|
|
||||||
// Storage for the last generated token from each query, passed to the next
|
|
||||||
// Transformer() call.
|
|
||||||
std::vector<int> gen_tokens(num_queries);
|
|
||||||
|
|
||||||
// Stream the last prompt token from each query and fill gen_tokens.
|
|
||||||
TokenStreamer token_streamer(runtime_config, config);
|
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
|
||||||
size_t last_token_pos_in_prompt =
|
|
||||||
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
|
|
||||||
gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
|
|
||||||
(void)token_streamer(query_idx_start + query_idx,
|
|
||||||
queries_mutable_pos[query_idx], gen_tokens[query_idx],
|
|
||||||
0.0f);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
const size_t vocab_size = config.vocab_size;
|
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
bool all_queries_eos =
|
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
DecodeStepT(config, weights, runtime_config, query_idx_start,
|
queries_prefix_end, div_seq_len, config, runtime_config,
|
||||||
queries_prompt, queries_mutable_pos, queries_prefix_end,
|
weights, sample_token, activations, kv_caches, env, non_eos,
|
||||||
div_seq_len, vocab_size, sample_token, activations,
|
timing_info);
|
||||||
kv_caches, token_streamer, gen_tokens, timing_info, env);
|
}
|
||||||
|
|
||||||
if (all_queries_eos) break;
|
|
||||||
} // foreach token to generate
|
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
|
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||||
|
const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
const ModelWeightsPtrs& weights, KVCache& kv_cache,
|
||||||
KVCache& kv_cache, MatMulEnv& env,
|
MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
TimingInfo& timing_info) {
|
|
||||||
constexpr size_t kNumQueries = 1;
|
constexpr size_t kNumQueries = 1;
|
||||||
const size_t qbatch_start = 0;
|
const size_t qbatch_start = 0;
|
||||||
|
|
||||||
|
|
@ -568,25 +615,27 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
|
||||||
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
|
||||||
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
const KVCaches kv_caches{&kv_cache, kNumQueries};
|
||||||
|
|
||||||
GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt,
|
GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
|
||||||
queries_pos, queries_prefix_end, activations, kv_caches,
|
config, runtime_config, weights, activations, kv_caches, env,
|
||||||
timing_info, env);
|
timing_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
|
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
|
||||||
const RuntimeConfig& runtime_config,
|
// queries, and calls `GenerateT` on each batch.
|
||||||
const QueriesPromptTokens& queries_prompt,
|
void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesPos& queries_pos,
|
const QueriesPos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end,
|
const QueriesPos& queries_prefix_end,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env,
|
const ModelConfig& config,
|
||||||
TimingInfo& timing_info) {
|
const RuntimeConfig& runtime_config,
|
||||||
|
const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
|
||||||
|
MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_ASSERT(queries_pos.size() == num_queries);
|
HWY_ASSERT(queries_pos.size() == num_queries);
|
||||||
HWY_ASSERT(kv_caches.size() >= num_queries);
|
HWY_ASSERT(kv_caches.size() >= num_queries);
|
||||||
|
|
||||||
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
|
||||||
const size_t max_batch_size =
|
const size_t max_batch_size =
|
||||||
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
|
||||||
|
|
||||||
Activations activations(config, max_batch_size, env.row_ptrs);
|
Activations activations(config, max_batch_size, env.row_ptrs);
|
||||||
|
|
||||||
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
for (size_t qbatch_start = 0; qbatch_start < num_queries;
|
||||||
|
|
@ -600,17 +649,16 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
|
||||||
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
|
||||||
qbatch_size);
|
qbatch_size);
|
||||||
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
|
||||||
GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts,
|
GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
|
||||||
qbatch_pos, qbatch_prefix_end, activations, qbatch_kv,
|
config, runtime_config, weights, activations, qbatch_kv, env,
|
||||||
timing_info, env);
|
timing_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateImageTokensT(const ModelConfig& config,
|
void GenerateImageTokensT(const ModelConfig& config,
|
||||||
const ModelWeightsPtrs& weights,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const Image& image, ImageTokens& image_tokens,
|
const ModelWeightsPtrs& weights, const Image& image,
|
||||||
MatMulEnv& env) {
|
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||||
if (config.vit_config.layer_configs.empty()) {
|
if (config.vit_config.layer_configs.empty()) {
|
||||||
HWY_ABORT("Model does not support generating image tokens.");
|
HWY_ABORT("Model does not support generating image tokens.");
|
||||||
}
|
}
|
||||||
|
|
@ -667,9 +715,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||||
KVCache& kv_cache, TimingInfo& timing_info) const {
|
KVCache& kv_cache, TimingInfo& timing_info) const {
|
||||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||||
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
|
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
|
||||||
runtime_config, prompt, pos, prefix_end,
|
model_.Config(), runtime_config,
|
||||||
kv_cache, env_, timing_info);
|
weights_, kv_cache, env_, timing_info);
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
@ -681,19 +729,19 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||||
const KVCaches& kv_caches,
|
const KVCaches& kv_caches,
|
||||||
TimingInfo& timing_info) const {
|
TimingInfo& timing_info) const {
|
||||||
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
|
||||||
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
|
QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
|
||||||
std::vector<size_t> prefix_end_vec;
|
std::vector<size_t> prefix_end_vec;
|
||||||
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
|
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
|
||||||
prefix_end_vec.resize(queries_prompt.size(), 0);
|
prefix_end_vec.resize(queries_prompt.size(), 0);
|
||||||
mutable_queries_prefix_end =
|
queries_prefix_end_or_zeros =
|
||||||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||||
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
|
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
|
||||||
model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
|
queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
|
||||||
mutable_queries_prefix_end, kv_caches, env_, timing_info);
|
runtime_config, weights_, kv_caches, env_, timing_info);
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
@ -704,7 +752,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||||
|
|
||||||
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
|
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
|
||||||
model_.Config(), weights_, runtime_config, image, image_tokens, env_);
|
model_.Config(), runtime_config, weights_, image, image_tokens, env_);
|
||||||
|
|
||||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,10 +26,7 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
constexpr int BOS_ID = 2; // beginning of sequence
|
||||||
constexpr int EOS_ID = 1;
|
|
||||||
constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3
|
|
||||||
constexpr int BOS_ID = 2;
|
|
||||||
|
|
||||||
// To avoid the complexity of storing the tokenizer into testdata/ or
|
// To avoid the complexity of storing the tokenizer into testdata/ or
|
||||||
// downloading from gs://, while still always writing a blob for the tokenizer,
|
// downloading from gs://, while still always writing a blob for the tokenizer,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue