mirror of https://github.com/google/gemma.cpp.git
Fix pos assertions, refs #665
Ensure the streaming func pos matches the number of calls. Add two arguments
that control pos+1 and pos+=1 behavior. Also cleanup/add comments.

run: use batch_stream_func, add assert, higher verbosity for MM autotune output

PiperOrigin-RevId: 799511163
parent 9bf0fe4e37
commit ed2f0bd1b0
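The new assertions in this commit rely on a simple contract: the `pos` passed to the per-token streaming callback advances by exactly one per call, starting from the position given to `Generate`, so a caller can count invocations and compare the count against `pos` (run.cc below does exactly that with `HWY_ASSERT(pos == abs_pos)`). A minimal caller-side sketch of such a check; `PosChecker` and its members are illustrative names, and only the (query_idx, pos, token, prob) callback shape is taken from the diff:

#include <cassert>
#include <cstddef>

// Sketch: verifies the contract that `pos` equals the starting position plus
// the number of streaming calls seen so far. Not gemma.cpp API; the struct and
// field names here are hypothetical.
struct PosChecker {
  size_t start_pos = 0;  // position passed to Generate()
  size_t calls = 0;      // streaming callback invocations observed

  // Same (query_idx, pos, token, prob) shape as `batch_stream_token` below.
  bool operator()(size_t /*query_idx*/, size_t pos, int /*token*/,
                  float /*prob*/) {
    assert(pos == start_pos + calls);  // pos advances by one per call
    ++calls;
    return true;  // keep generating
  }
};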
@@ -138,6 +138,10 @@ TEST_F(GemmaTest, Multiturn) {
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
+  // -1 because our prefill does not generate KVs for the last token. Do not
+  // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+  HWY_ASSERT(abs_pos > 0);
+  --abs_pos;
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   s_env->MutableEnv(), timing_info);
   fprintf(stderr, "decoded: '%s'\n", response.c_str());

gemma/gemma.cc (115 changed lines)

@@ -127,9 +127,10 @@ static float EmbeddingScaling(size_t model_dim) {
       hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
 }

-// `batch_idx` indicates which row of `x` to write to.
-// `pos` is the *token*'s position, not the start of the batch, because this is
-// called for batches of tokens in prefill, but batches of queries in decode.
+// `x_row` indicates which row of `x` to write to.
+// `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not
+// the start of the batch, because this is called for batches of tokens in
+// prefill, but batches of queries in decode.
 //
 // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
 // spec) until we run out of image tokens. This allows for a multi-image prompt

@@ -137,7 +138,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const WeightsPtrs& weights,
              MatStorageT<float>& x, ThreadingContext& ctx,
             const ImageTokens* image_tokens = nullptr,

@@ -146,14 +147,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }

   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }

@@ -174,14 +175,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim,
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim,
                ctx.profiler, worker);
   });

   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos);
   }
   return image_token_position;
 }

@@ -249,24 +250,12 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
     for (size_t ti = 0; ti < tbatch_size; ++ti) {
       const size_t pos = qbatch_1.Pos(0) + ti;
       const size_t pos_in_prompt = tbatch_start + ti;
+      HWY_DASSERT(pos_in_prompt < prompt_size);
       const int token = qbatch_1.Prompt(0)[pos_in_prompt];
       image_token_position = EmbedMMToken(
           token, ti, pos, pos_in_prompt, config, weights, activations.x,
           env.ctx, runtime_config.image_tokens, image_token_position);
-    }
-
-    // Transformer with one batch of tokens from a single query.
-    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
-         ++layer_idx) {
-      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, qbatch_1, env);
-    }
-
-    // NOTE: we unconditionally call StreamToken, even if EOS.
-    for (size_t ti = 0; ti < tbatch_size; ++ti) {
-      const size_t pos = qbatch_1.Pos(0) + ti;
-      const size_t pos_in_prompt = tbatch_start + ti;
-      const int token = qbatch_1.Prompt(0)[pos_in_prompt];
+      // NOTE: we unconditionally call StreamToken, even if EOS.
       if (pos_in_prompt < prompt_size - 1) {
         runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
       } else {

@@ -276,6 +265,14 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
       }
     }

+    // Transformer with one batch of tokens from a single query. No need to
+    // set `PrevToken` because we already did the embedding above.
+    for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
+         ++layer_idx) {
+      TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                       activations, qbatch_1, env);
+    }
+
     qbatch_1.MutablePos(0) += tbatch_size;
   }  // for tbatch_start
   if (attend_to_last_token) {

@@ -291,8 +288,8 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 }

 // Embeds PrevToken (one from each query) and calls each TransformerLayer.
-// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
-// token-batched `PrefillTBatch`.
+// Called by query-batched `PrefillQBatch` and `GenerateT`, but not the
+// token-batched `PrefillTBatch`, which supports image embedding.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
                                      const WeightsPtrs& weights,

@@ -324,8 +321,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
   }
 }

-// Populates KV cache for the batch queries, one token at a time. Only called
-// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
+// Populates KV cache for the batch queries, one token at a time.
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,

@@ -337,6 +333,8 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,

   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
+
+    // Should only be called for autoregressive (non-prefix-LM) prefill.
     HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }

@@ -358,7 +356,7 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
     }

     // The input (PrevToken) is one token from each query in the batch.
-    // Do not call DecodeStepT because it computes logits for token
+    // Do not call `SampleAndStream` because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
     Transformer(config, runtime_config, weights, activations, qbatch, env);
   }

@@ -369,42 +367,40 @@ static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
 }

 // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
-// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the
 // query is at the end of its sequence.
 static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
                                const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+                               QBatch& qbatch, bool pos_plus_1, bool update_pos,
+                               hwy::BitSet4096<>& non_eos) {
   HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.

-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
-                                               qbatch.Pos(qi), token, prob))) {
+  const size_t pos = qbatch.Pos(qi) + (pos_plus_1 ? 1 : 0);
+  if (HWY_UNLIKELY(
+          !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) {
     // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
     HWY_DASSERT(config.IsEOS(token));
   }

   qbatch.PrevToken(qi) = token;
-  qbatch.MutablePos(qi) += 1;
+  qbatch.MutablePos(qi) += update_pos ? 1 : 0;

   // Primary or secondary EOS: mark query as EOS, but still increment (for
   // multi-turn, we should still keep the prior EOS).
   if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }

-// For a batch of queries, runs Transformer, computes logits, samples and
-// streams the token.
-static void DecodeStepT(const ModelConfig& config,
-                        const RuntimeConfig& runtime_config,
-                        const WeightsPtrs& weights,
-                        const SampleFunc& sample_token,
-                        Activations& activations, QBatch& qbatch,
-                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-                        TimingInfo& timing_info) {
+// Must be called after Transformer: either after prefill, or during decode.
+// Computes logits, samples and streams the token.
+static void SampleAndStream(
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const WeightsPtrs& weights, const SampleFunc& sample_token,
+    Activations& activations, QBatch& qbatch, bool update_pos, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   HWY_DASSERT(qbatch.Size() == activations.x.Rows());

-  Transformer(config, runtime_config, weights, activations, qbatch, env);
-
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x, env.ctx);

   if (HWY_UNLIKELY(runtime_config.activations_observer)) {

@@ -427,8 +423,12 @@ static void DecodeStepT(const ModelConfig& config,
     const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();

+    // We streamed all prefill tokens, but pos is still one behind because we
+    // started generation at pos = prompt.size() - 1. We want the pos argument
+    // to match the number of calls to `StreamToken`, as expected by the caller.
+    const bool pos_plus_1 = true;
     StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
-                       non_eos);
+                       pos_plus_1, update_pos, non_eos);
   });
 }

@@ -476,15 +476,16 @@ static void GenerateT(const ModelConfig& config,
   const size_t seq_len = qbatch.KV(0).SeqLen();
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const PromptTokens& prompt = qbatch.Prompt(qi);
+    // Sanity check: prompts should not be empty. Note that multi-turn prompts
+    // start with <end_of_turn>.
+    HWY_ASSERT(prompt.size() != 0);
+
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());

     // Prefill stops before size - 1 because the last prompt token is the
     // first input token for generation.
     total_prefill_tokens += prompt.size() - 1;

-    // Sanity check: prompts should not be empty, nor start with EOS.
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
-
     all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;

     // We use a single divisor, so all sequence lengths must be the same.

@@ -518,14 +519,12 @@ static void GenerateT(const ModelConfig& config,
   // Stream the last prompt token from each query, fill activations.gen_tokens.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    const bool pos_plus_1 = false;  // during prefill, pos is still correct.
+    // In autoregressive mode, we have not prefilled the last token, so do
+    // not advance.
+    const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
     StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
-                       runtime_config, qbatch, non_eos);
-    // StreamAndUpdateEOS() sets the stream position one token too far in
-    // autoregressive mode.
-    const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
-    if (!attend_to_last_token) {
-      qbatch.MutablePos(qi) -= 1;
-    }
+                       runtime_config, qbatch, pos_plus_1, update_pos, non_eos);
   }

   size_t max_gen_steps = runtime_config.max_generated_tokens;

@@ -540,8 +539,10 @@ static void GenerateT(const ModelConfig& config,
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(config, runtime_config, weights, sample_token, activations,
-                  qbatch, env, non_eos, timing_info);
+      Transformer(config, runtime_config, weights, activations, qbatch, env);
+      SampleAndStream(config, runtime_config, weights, sample_token,
+                      activations, qbatch, /*update_pos=*/true, env, non_eos,
+                      timing_info);
     }
     timing_info.NotifyGenerateDone();
   }

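To restate the two flags introduced above (a paraphrase of the gemma.cc hunks, not additional library code): `pos_plus_1` only changes the position reported to `StreamToken`, while `update_pos` only changes whether the query's stored position advances. A small sketch, with hypothetical helper names:

#include <cstddef>

// Sketch of how StreamAndUpdateEOS's new flags split responsibilities; the
// struct and function names are illustrative, the arithmetic mirrors the diff.
struct StreamPosUpdate {
  size_t reported;  // value passed as `pos` to StreamToken
  size_t stored;    // new value written back via MutablePos
};

StreamPosUpdate ApplyFlags(size_t pos, bool pos_plus_1, bool update_pos) {
  return {pos + (pos_plus_1 ? 1 : 0), pos + (update_pos ? 1 : 0)};
}

// Per the hunks above: decode passes pos_plus_1=true and update_pos=true; the
// last prompt token passes pos_plus_1=false and update_pos only when the token
// was already prefilled (Pos < PrefixEnd).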
gemma/run.cc (16 changed lines)

@@ -132,7 +132,12 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   }

   // callback function invoked for each generated token.
-  auto stream_token = [&](int token, float) {
+  auto batch_stream_token = [&](size_t query_idx, size_t pos, int token,
+                                float) {
+    std::string token_text;
+    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;

@@ -148,8 +153,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       }
       return true;
     }
-    std::string token_text;
-    HWY_ASSERT(gemma.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     if (first_response_token) {
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
       if (inference.verbosity >= 1) {

@@ -187,7 +190,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
   TimingInfo timing_info = {.verbosity = inference.verbosity};
   RuntimeConfig runtime_config = {.gen = &gen,
                                   .verbosity = inference.verbosity,
-                                  .stream_token = stream_token,
+                                  .batch_stream_token = batch_stream_token,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;

@@ -223,6 +226,9 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     if (inference.verbosity >= 1) {
       std::cerr << "\n[ Reading prompt ] " << std::flush;
     }
+    // -1 because our prefill does not generate KVs for the last token. Do not
+    // just pass abs_pos - 1 because our callback checks pos == abs_pos.
+    if (abs_pos > 0) --abs_pos;
     gemma.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, env,
                    timing_info);
     std::cout << "\n\n";

@@ -255,7 +261,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,

   ThreadingContext ctx(threading);
   MatMulEnv env(ctx);
-  if (inference.verbosity >= 2) env.print_best = true;
+  if (inference.verbosity >= 3) env.print_best = true;
   const Gemma gemma(loader, inference, ctx);
   KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
