Add token validity assertions; improve prompt truncation warning

PiperOrigin-RevId: 627376194
This commit is contained in:
Paul Chang 2024-04-23 07:05:16 -07:00 committed by Copybara-Service
parent 3bf22abb22
commit e8d29792ac
1 changed file with 8 additions and 4 deletions

View File

@ -890,6 +890,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
const int token = tokens[token_idx];
HWY_ASSERT(token >= 0);
HWY_ASSERT(token < TConfig::kVocabSize);
Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim);
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
@ -1009,10 +1011,11 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
if (!TConfig::kUseLocalAttention) {
if (prompt_size + max_generated_tokens > max_tokens) {
fprintf(stderr,
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
"%d, truncating.\n",
prompt_size, max_generated_tokens, TConfig::kSeqLen);
prompt_size = max_tokens - max_generated_tokens;
"WARNING: prompt_size %zu + max_generated_tokens %zu > "
"max_tokens %zu, truncating to ",
prompt_size, max_generated_tokens, max_tokens);
prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
fprintf(stderr, "%zu\n", prompt_size);
}
}
}
@ -1040,6 +1043,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
max_tokens);
return;
}
HWY_ASSERT(prompt_size > 0);
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
//