Merge pull request #127 from szabadka:gemma3

PiperOrigin-RevId: 621815677
This commit is contained in:
Copybara-Service 2024-04-04 04:32:03 -07:00
commit 08948f13ac
2 changed files with 15 additions and 5 deletions

View File

@ -662,12 +662,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t pos_gen_start = pos_offset; size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset); int token = prompt.at(pos_offset);
stream_token(token, 0);
for (size_t generate_pos = 0; for (size_t generate_pos = 0;
pos < max_tokens && generate_pos < max_generated_tokens; pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) { ++pos, ++pos_offset, ++generate_pos) {
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
float* final_activation = activations.x.data(); float* final_activation = activations.x.data();
if (pos_offset >= prompt_size) { // The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
// is disabled.
if (pos_offset >= prompt_size - 1) {
PROFILER_ZONE("Gen.Embedding"); PROFILER_ZONE("Gen.Embedding");
// Generation phase // Generation phase
MatVec<kVocabSize, TConfig::kModelDim>( MatVec<kVocabSize, TConfig::kModelDim>(
@ -677,9 +681,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
Softmax(activations.logits.data(), kVocabSize); Softmax(activations.logits.data(), kVocabSize);
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize, token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
gen, temperature, accept_token); gen, temperature, accept_token);
} if (!stream_token(token, activations.logits[token])) {
if (!stream_token(token, activations.logits[token])) { token = EOS_ID;
token = EOS_ID; }
} else {
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
stream_token(token, 0);
} }
if (token == EOS_ID) { if (token == EOS_ID) {
if (verbosity >= 2) { if (verbosity >= 2) {

3
run.cc
View File

@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
verbosity](int token, float) { verbosity](int token, float) {
++abs_pos; ++abs_pos;
++current_pos; ++current_pos;
if (current_pos < prompt_size) { // <= since position is incremented before
if (current_pos <= prompt_size) {
std::cerr << "." << std::flush; std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) { } else if (token == gcpp::EOS_ID) {
if (!args.multiturn) { if (!args.multiturn) {