internal change

PiperOrigin-RevId: 718824952
This commit is contained in:
Apoorv Reddy 2025-01-23 05:28:51 -08:00 committed by Copybara-Service
parent a60b564b88
commit ce807a31a1
1 changed file with 37 additions and 35 deletions

View File

@ -28,7 +28,6 @@
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"
// Placeholder for internal test4, do not remove
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" #include "util/basics.h"
@ -1312,42 +1311,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
0.0f); 0.0f);
} }
const size_t vocab_size = model.Config().vocab_size; {
const double gen_start = hwy::platform::Now(); const size_t vocab_size = model.Config().vocab_size;
for (size_t gen = 0; gen < max_generated_tokens; ++gen) { const double gen_start = hwy::platform::Now();
// Decode generates one token per query and increments queries_mutable_pos. for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
Transformer(QueriesToken(gen_tokens.data(), num_queries), // Decode generates one token per query and increments
queries_mutable_pos, queries_prefix_end, weights, activations, // queries_mutable_pos.
div_seq_len, kv_caches, runtime_config.layers_output, Transformer(QueriesToken(gen_tokens.data(), num_queries),
runtime_config.activations_observer); queries_mutable_pos, queries_prefix_end, weights, activations,
// queries_pos are incremented by Transformer. div_seq_len, kv_caches, runtime_config.layers_output,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
bool all_queries_eos = true; bool all_queries_eos = true;
{ {
PROFILER_ZONE("Gen.EmbeddingMatmul"); PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations. // Compute logits from last layer activations.
MatMul(ConstMatFromBatch(num_queries, activations.x), MatMul(ConstMatFromBatch(num_queries, activations.x),
ConstMatFromWeights(weights.embedder_input_embedding), ConstMatFromWeights(weights.embedder_input_embedding),
/*add=*/nullptr, *activations.env, /*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.logits)); RowPtrFromBatch(activations.logits));
} }
PROFILER_ZONE("Gen.Softcap+Sample+Stream"); PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
const TokenAndProb tp = sample_token(logits, vocab_size); vocab_size);
timing_info.NotifyGenerated(prefill_start, gen_start); const TokenAndProb tp = sample_token(logits, vocab_size);
timing_info.NotifyGenerated(prefill_start, gen_start);
const bool is_eos = const bool is_eos =
token_streamer(query_idx_start + query_idx, token_streamer(query_idx_start + query_idx,
queries_mutable_pos[query_idx], tp.token, tp.prob); queries_mutable_pos[query_idx], tp.token, tp.prob);
all_queries_eos &= is_eos; all_queries_eos &= is_eos;
gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token; gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
} }
if (all_queries_eos) break; if (all_queries_eos) break;
} // foreach token to generate } // foreach token to generate
timing_info.NotifyGenerateDone(gen_start);
timing_info.NotifyGenerateDone(gen_start); }
} }
template <typename T> template <typename T>