mirror of https://github.com/google/gemma.cpp.git
Fix the position calculation issue in the generation phase
This commit is contained in:
parent
ea72575e56
commit
8c634f6486
|
|
@@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
TokenStreamer token_streamer(runtime_config);
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
||||
- (void)token_streamer(query_idx_start + query_idx, prefill_per_query,
|
||||
+ (void)token_streamer(query_idx_start + query_idx,
|
||||
+ pos[query_idx] + prefill_per_query,
|
||||
gen_tokens[query_idx], 0.0f);
|
||||
}
|
||||
|
||||
|
|
@@ -1020,9 +1021,10 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
|||
const int token = sample_token(logits, kVocabSize);
|
||||
timing_info.NotifyGenerated(prefill_start, gen_start);
|
||||
|
||||
- const bool is_eos = token_streamer(query_idx_start + query_idx,
|
||||
- prefill_per_query + 1 + gen_per_query,
|
||||
- token, logits[token]);
|
||||
+ const bool is_eos =
|
||||
+ token_streamer(query_idx_start + query_idx,
|
||||
+ pos[query_idx] + prefill_per_query + 1 + gen_per_query,
|
||||
+ token, logits[token]);
|
||||
all_queries_eos &= is_eos;
|
||||
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue