mirror of https://github.com/google/gemma.cpp.git
New token validity assertions, improve prompt truncation warning
PiperOrigin-RevId: 627376194
This commit is contained in:
parent
3bf22abb22
commit
e8d29792ac
|
|
@ -890,6 +890,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
|||
pool.Run(
|
||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||
const int token = tokens[token_idx];
|
||||
HWY_ASSERT(token >= 0);
|
||||
HWY_ASSERT(token < TConfig::kVocabSize);
|
||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||
|
|
@ -1009,10 +1011,11 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
|||
if (!TConfig::kUseLocalAttention) {
|
||||
if (prompt_size + max_generated_tokens > max_tokens) {
|
||||
fprintf(stderr,
|
||||
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
|
||||
"%d, truncating.\n",
|
||||
prompt_size, max_generated_tokens, TConfig::kSeqLen);
|
||||
prompt_size = max_tokens - max_generated_tokens;
|
||||
"WARNING: prompt_size %zu + max_generated_tokens %zu > "
|
||||
"max_tokens %zu, truncating to ",
|
||||
prompt_size, max_generated_tokens, max_tokens);
|
||||
prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
|
||||
fprintf(stderr, "%zu\n", prompt_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1040,6 +1043,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
|||
max_tokens);
|
||||
return;
|
||||
}
|
||||
HWY_ASSERT(prompt_size > 0);
|
||||
|
||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||
//
|
||||
|
|
|
|||
Loading…
Reference in New Issue