Fix off-by-one errors in generation code and token streaming callback.

In the generation code we were feeding the last token of the prompt
twice through the transformer. The new version fixes that and also
works in the case where Prefill is completely disabled.
This commit is contained in:
Zoltan Szabadka 2024-04-04 07:50:03 +00:00
parent ede337f876
commit 71ead04afb
2 changed files with 15 additions and 5 deletions

View File

@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset);
stream_token(token, 0);
for (size_t generate_pos = 0;
pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
float* final_activation = activations.x.data();
if (pos_offset >= prompt_size) {
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
// is disabled.
if (pos_offset >= prompt_size - 1) {
PROFILER_ZONE("Gen.Embedding");
// Generation phase
MatVec<kVocabSize, TConfig::kModelDim>(
@ -681,10 +685,15 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
Softmax(activations.logits.data(), kVocabSize);
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
gen, temperature, accept_token);
}
if (!stream_token(token, activations.logits[token])) {
token = EOS_ID;
}
} else {
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
stream_token(token, 0);
}
if (token == EOS_ID) {
if (verbosity >= 2) {
const double gen_end = hwy::platform::Now();

3
run.cc
View File

@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
verbosity](int token, float) {
++abs_pos;
++current_pos;
if (current_pos < prompt_size) {
// <= since position is incremented before
if (current_pos <= prompt_size) {
std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) {
if (!args.multiturn) {