mirror of https://github.com/google/gemma.cpp.git
Add token validity assertions and improve the prompt truncation warning
PiperOrigin-RevId: 627376194
This commit is contained in:
parent
3bf22abb22
commit
e8d29792ac
|
|
@ -890,6 +890,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
|
||||||
pool.Run(
|
pool.Run(
|
||||||
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
|
||||||
const int token = tokens[token_idx];
|
const int token = tokens[token_idx];
|
||||||
|
HWY_ASSERT(token >= 0);
|
||||||
|
HWY_ASSERT(token < TConfig::kVocabSize);
|
||||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||||
activations.x.data() + token_idx * kModelDim, kModelDim);
|
activations.x.data() + token_idx * kModelDim, kModelDim);
|
||||||
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
|
||||||
|
|
@ -1009,10 +1011,11 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
if (!TConfig::kUseLocalAttention) {
|
if (!TConfig::kUseLocalAttention) {
|
||||||
if (prompt_size + max_generated_tokens > max_tokens) {
|
if (prompt_size + max_generated_tokens > max_tokens) {
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
|
"WARNING: prompt_size %zu + max_generated_tokens %zu > "
|
||||||
"%d, truncating.\n",
|
"max_tokens %zu, truncating to ",
|
||||||
prompt_size, max_generated_tokens, TConfig::kSeqLen);
|
prompt_size, max_generated_tokens, max_tokens);
|
||||||
prompt_size = max_tokens - max_generated_tokens;
|
prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
|
||||||
|
fprintf(stderr, "%zu\n", prompt_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1040,6 +1043,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
max_tokens);
|
max_tokens);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
HWY_ASSERT(prompt_size > 0);
|
||||||
|
|
||||||
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
// pos indexes the KV cache. In the first turn of a chat, pos = 0.
|
||||||
//
|
//
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue