diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index bef03c4..ab43b69 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -890,6 +890,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         const int token = tokens[token_idx];
+        HWY_ASSERT(token >= 0);
+        HWY_ASSERT(token < TConfig::kVocabSize);
         Decompress(weights.embedder_input_embedding, token * kModelDim,
                    activations.x.data() + token_idx * kModelDim, kModelDim);
         MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
@@ -1009,10 +1011,11 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
   if (!TConfig::kUseLocalAttention) {
     if (prompt_size + max_generated_tokens > max_tokens) {
       fprintf(stderr,
-              "WARNING: prompt_size %zu + max_generated_tokens %zu > kSeqLen "
-              "%d, truncating.\n",
-              prompt_size, max_generated_tokens, TConfig::kSeqLen);
-      prompt_size = max_tokens - max_generated_tokens;
+              "WARNING: prompt_size %zu + max_generated_tokens %zu > "
+              "max_tokens %zu, truncating to ",
+              prompt_size, max_generated_tokens, max_tokens);
+      prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
+      fprintf(stderr, "%zu\n", prompt_size);
     }
   }
 }
@@ -1040,6 +1043,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
             max_tokens);
     return;
   }
+  HWY_ASSERT(prompt_size > 0);
 
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
   //
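
For reviewers, here is a minimal standalone sketch of the patched truncation logic (the harness and the name `TruncatePrompt` are hypothetical, not part of the patch). It illustrates why the `std::min` guard matters: the operands are `size_t`, so if `max_generated_tokens` ever exceeded `max_tokens`, the old unconditional `prompt_size = max_tokens - max_generated_tokens;` would wrap around and set `prompt_size` to a huge value instead of truncating it.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Mirrors the patched branch of RangeChecks() for the !kUseLocalAttention
// case (simplified; free function instead of the templated original).
void TruncatePrompt(size_t max_tokens, size_t max_generated_tokens,
                    size_t& prompt_size) {
  if (prompt_size + max_generated_tokens > max_tokens) {
    fprintf(stderr,
            "WARNING: prompt_size %zu + max_generated_tokens %zu > "
            "max_tokens %zu, truncating to ",
            prompt_size, max_generated_tokens, max_tokens);
    // std::min keeps prompt_size unchanged if the unsigned subtraction
    // would wrap (i.e. max_generated_tokens > max_tokens).
    prompt_size = std::min(prompt_size, max_tokens - max_generated_tokens);
    fprintf(stderr, "%zu\n", prompt_size);
  }
}

int main() {
  size_t a = 100;
  TruncatePrompt(/*max_tokens=*/64, /*max_generated_tokens=*/32, a);
  printf("truncated: %zu\n", a);  // 32 == 64 - 32

  size_t b = 100;
  TruncatePrompt(/*max_tokens=*/64, /*max_generated_tokens=*/80, b);
  printf("guarded:   %zu\n", b);  // stays 100; old code would wrap to ~2^64
  return 0;
}
```

The asserts in the other two hunks complement this: the `HWY_ASSERT`s in `Prefill` catch out-of-range token ids before they index `embedder_input_embedding`, and `HWY_ASSERT(prompt_size > 0)` in `GenerateImpl` catches a prompt that was truncated to nothing.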