diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 96e8a63..8476b5c 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, TokenStreamer token_streamer(runtime_config); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { gen_tokens[query_idx] = prompts[query_idx][prefill_per_query]; - (void)token_streamer(query_idx_start + query_idx, prefill_per_query, + (void)token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query, gen_tokens[query_idx], 0.0f); } @@ -1020,9 +1021,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const int token = sample_token(logits, kVocabSize); timing_info.NotifyGenerated(prefill_start, gen_start); - const bool is_eos = token_streamer(query_idx_start + query_idx, - prefill_per_query + 1 + gen_per_query, - token, logits[token]); + const bool is_eos = + token_streamer(query_idx_start + query_idx, + pos[query_idx] + prefill_per_query + 1 + gen_per_query, + token, logits[token]); all_queries_eos &= is_eos; gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; }