mirror of https://github.com/google/gemma.cpp.git
Fix the position calculation issue in the generation phase
This commit is contained in:
parent
ea72575e56
commit
8c634f6486
|
|
@ -987,7 +987,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
TokenStreamer token_streamer(runtime_config);
|
TokenStreamer token_streamer(runtime_config);
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
gen_tokens[query_idx] = prompts[query_idx][prefill_per_query];
|
||||||
(void)token_streamer(query_idx_start + query_idx, prefill_per_query,
|
(void)token_streamer(query_idx_start + query_idx,
|
||||||
|
pos[query_idx] + prefill_per_query,
|
||||||
gen_tokens[query_idx], 0.0f);
|
gen_tokens[query_idx], 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1020,8 +1021,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
const int token = sample_token(logits, kVocabSize);
|
const int token = sample_token(logits, kVocabSize);
|
||||||
timing_info.NotifyGenerated(prefill_start, gen_start);
|
timing_info.NotifyGenerated(prefill_start, gen_start);
|
||||||
|
|
||||||
const bool is_eos = token_streamer(query_idx_start + query_idx,
|
const bool is_eos =
|
||||||
prefill_per_query + 1 + gen_per_query,
|
token_streamer(query_idx_start + query_idx,
|
||||||
|
pos[query_idx] + prefill_per_query + 1 + gen_per_query,
|
||||||
token, logits[token]);
|
token, logits[token]);
|
||||||
all_queries_eos &= is_eos;
|
all_queries_eos &= is_eos;
|
||||||
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : token;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue