Fix the position calculation issue in the generation phase

This commit is contained in:
RangerUFO 2024-08-12 02:33:10 +08:00 committed by Jan Wassenberg
parent ea72575e56
commit 8c634f6486
1 changed file with 6 additions and 4 deletions

View File

@@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
TokenStreamer token_streamer(runtime_config); TokenStreamer token_streamer(runtime_config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query]; gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
(void)token_streamer(query_idx_start + query_idx, prefill_per_query, (void)token_streamer(query_idx_start + query_idx,
pos[query_idx] + prefill_per_query,
gen_tokens[query_idx], 0.0f); gen_tokens[query_idx], 0.0f);
} }
@@ -1020,8 +1021,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const int token = sample_token(logits, kVocabSize); const int token = sample_token(logits, kVocabSize);
timing_info.NotifyGenerated(prefill_start, gen_start); timing_info.NotifyGenerated(prefill_start, gen_start);
const bool is_eos = token_streamer(query_idx_start + query_idx, const bool is_eos =
prefill_per_query + 1 + gen_per_query, token_streamer(query_idx_start + query_idx,
pos[query_idx] + prefill_per_query + 1 + gen_per_query,
token, logits[token]); token, logits[token]);
all_queries_eos &= is_eos; all_queries_eos &= is_eos;
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token; gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;