From 8c634f6486a3f0c33cbb5357b311cd6222a62429 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Mon, 12 Aug 2024 02:33:10 +0800 Subject: [PATCH] Fix the position calculation issue in the generation phase --- gemma/gemma-inl.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 96e8a63..8476b5c 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, TokenStreamer token_streamer(runtime_config); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { gen_tokens[query_idx] = prompts[query_idx][prefill_per_query]; - (void)token_streamer(query_idx_start + query_idx, prefill_per_query, + (void)token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query, gen_tokens[query_idx], 0.0f); } @@ -1020,9 +1021,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const int token = sample_token(logits, kVocabSize); timing_info.NotifyGenerated(prefill_start, gen_start); - const bool is_eos = token_streamer(query_idx_start + query_idx, - prefill_per_query + 1 + gen_per_query, - token, logits[token]); + const bool is_eos = + token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query + 1 + gen_per_query, + token, logits[token]); all_queries_eos &= is_eos; gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; }