Fix off-by-one errors in generation code and token streaming callback.

In the generation code we were feeding the last token of the prompt
twice through the transformer. The new version fixes that and also
works in the case where Prefill is completely disabled.
Zoltan Szabadka 2024-04-04 07:50:03 +00:00
parent ede337f876
commit 71ead04afb
2 changed files with 15 additions and 5 deletions
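
To make the off-by-one concrete before the diff, here is a minimal stand-alone sketch of the loop structure the commit message describes. The names (FakeTransformer, Generate) and the fake sampling are illustrative stand-ins, not gemma.cpp's API: with the old check against prompt_size, the sampled token from the first generation step is discarded and the last prompt token goes through the transformer twice.

// Minimal sketch (illustrative names, not gemma.cpp's API): the generation
// loop starts from the last prompt token, so the "past the prompt" check must
// compare against prompt_size - 1. With the old check, the first sampled
// token is discarded and the last prompt token is fed to the fake transformer
// twice.
#include <cstddef>
#include <cstdio>
#include <vector>

// Records every token fed to the fake transformer and returns a
// deterministic "sampled" token so the trace is easy to read.
static int FakeTransformer(int token, std::vector<int>& fed) {
  fed.push_back(token);
  return 100 + static_cast<int>(fed.size());
}

static std::vector<int> Generate(const std::vector<int>& prompt,
                                 std::size_t max_generated, bool buggy) {
  std::vector<int> fed;
  const std::size_t prompt_size = prompt.size();
  std::size_t pos = prompt_size - 1;  // prefill already handled tokens [0, pos)
  int token = prompt[pos];            // generation starts from the last prompt token
  for (std::size_t gen = 0; gen < max_generated; ++gen, ++pos) {
    const int sampled = FakeTransformer(token, fed);
    const std::size_t threshold = buggy ? prompt_size : prompt_size - 1;
    if (pos >= threshold) token = sampled;  // otherwise token keeps its old value
  }
  return fed;
}

int main() {
  const std::vector<int> prompt = {11, 22, 33};
  for (bool buggy : {true, false}) {
    std::printf(buggy ? "buggy: " : "fixed: ");
    for (int t : Generate(prompt, 3, buggy)) std::printf("%d ", t);
    std::printf("\n");
  }
  // Prints:
  //   buggy: 33 33 102   (the last prompt token, 33, is fed twice)
  //   fixed: 33 101 102
  return 0;
}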


@@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   size_t pos_gen_start = pos_offset;
   int token = prompt.at(pos_offset);
+  stream_token(token, 0);
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
     float* final_activation = activations.x.data();
-    if (pos_offset >= prompt_size) {
+    // The condition below is always true if we are doing Prefill above.
+    // We keep it here for clarity so that the code is correct even if Prefill
+    // is disabled.
+    if (pos_offset >= prompt_size - 1) {
       PROFILER_ZONE("Gen.Embedding");
       // Generation phase
       MatVec<kVocabSize, TConfig::kModelDim>(
@@ -681,9 +685,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
                                          gen, temperature, accept_token);
-    }
-    if (!stream_token(token, activations.logits[token])) {
-      token = EOS_ID;
+      if (!stream_token(token, activations.logits[token])) {
+        token = EOS_ID;
+      }
+    } else {
+      // We would take this branch if we were not doing Prefill but would
+      // process the tokens of the prompt one at a time.
+      token = prompt.at(pos_offset + 1);
+      stream_token(token, 0);
     }
     if (token == EOS_ID) {
       if (verbosity >= 2) {
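
The else branch added above is what makes the loop self-sufficient when Prefill is disabled: the prompt is then replayed one token per iteration before sampling starts. The following sketch of that path uses illustrative names (FakeSample, a fixed step count), not the real GenerateImpl.

// Sketch of the Prefill-disabled path (illustrative, not the real GenerateImpl):
// pos_offset starts at 0 and the loop replays the prompt one token per step,
// switching to sampled tokens only after the last prompt token has been fed.
#include <cstddef>
#include <cstdio>
#include <vector>

static int FakeSample(int token) { return token + 100; }  // stand-in for Transformer + SampleTopK

int main() {
  const std::vector<int> prompt = {11, 22, 33};
  const std::size_t prompt_size = prompt.size();
  std::size_t pos_offset = 0;          // no prefill: start at the first prompt token
  int token = prompt[pos_offset];
  for (std::size_t step = 0; step < 5; ++step, ++pos_offset) {
    const int sampled = FakeSample(token);  // "feed" the current token
    if (pos_offset >= prompt_size - 1) {
      token = sampled;                  // generation phase: keep the model output
    } else {
      token = prompt[pos_offset + 1];   // still inside the prompt: feed its next token
    }
    std::printf("step %zu, next token to feed: %d\n", step, token);
  }
  return 0;
}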

run.cc (3 changed lines)

@@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
                      verbosity](int token, float) {
     ++abs_pos;
     ++current_pos;
-    if (current_pos < prompt_size) {
+    // <= since position is incremented before
+    if (current_pos <= prompt_size) {
       std::cerr << "." << std::flush;
     } else if (token == gcpp::EOS_ID) {
       if (!args.multiturn) {
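
The generation code now also streams the last prompt token (the stream_token(token, 0) call added before the loop), so the callback is invoked once per prompt token, and because current_pos is incremented before the comparison, the k-th prompt call sees current_pos == k. The last prompt call therefore arrives with current_pos == prompt_size and needs <= to still count as prompt progress. A minimal sketch of that counting, with a hypothetical callback rather than the real ReplGemma lambda:

// Why the prompt check becomes <=: the callback runs once per prompt token
// (now including the last one) and current_pos is incremented before the
// comparison, so the last prompt call sees current_pos == prompt_size.
// (Hypothetical stand-in for the real ReplGemma callback.)
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t prompt_size = 3;
  std::size_t current_pos = 0;
  auto stream_token = [&](int token) {
    ++current_pos;  // incremented before the check, as in ReplGemma
    if (current_pos <= prompt_size) {
      std::printf("prompt token %d (current_pos=%zu)\n", token, current_pos);
    } else {
      std::printf("generated token %d (current_pos=%zu)\n", token, current_pos);
    }
  };
  // Three prompt tokens are streamed first, then two generated tokens.
  for (int t : {11, 22, 33, 101, 102}) stream_token(t);
  return 0;
}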