From ce807a31a16a7ceb942bf93fb35aec14ce4b3916 Mon Sep 17 00:00:00 2001
From: Apoorv Reddy
Date: Thu, 23 Jan 2025 05:28:51 -0800
Subject: [PATCH] internal change

PiperOrigin-RevId: 718824952
---
 gemma/gemma-inl.h | 72 ++++++++++++++++++++++++----------------
 1 file changed, 37 insertions(+), 35 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 9aa3d11..d7b3a79 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -28,7 +28,6 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
-// Placeholder for internal test4, do not remove
 #include "paligemma/image.h"
 #include "util/allocator.h"
 #include "util/basics.h"
@@ -1312,42 +1311,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
                    0.0f);
   }
 
-  const size_t vocab_size = model.Config().vocab_size;
-  const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
+  {
+    const size_t vocab_size = model.Config().vocab_size;
+    const double gen_start = hwy::platform::Now();
+    for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
+      // Decode generates one token per query and increments
+      // queries_mutable_pos.
+      Transformer(QueriesToken(gen_tokens.data(), num_queries),
+                  queries_mutable_pos, queries_prefix_end, weights, activations,
+                  div_seq_len, kv_caches, runtime_config.layers_output,
+                  runtime_config.activations_observer);
+      // queries_pos are incremented by Transformer.
 
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
+      bool all_queries_eos = true;
+      {
+        PROFILER_ZONE("Gen.EmbeddingMatmul");
+        // Compute logits from last layer activations.
+        MatMul(ConstMatFromBatch(num_queries, activations.x),
+               ConstMatFromWeights(weights.embedder_input_embedding),
+               /*add=*/nullptr, *activations.env,
+               RowPtrFromBatch(activations.logits));
+      }
+      PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+      for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+        float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+        MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
+                           vocab_size);
+        const TokenAndProb tp = sample_token(logits, vocab_size);
+        timing_info.NotifyGenerated(prefill_start, gen_start);
 
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
-    if (all_queries_eos) break;
-  }  // foreach token to generate
-
-  timing_info.NotifyGenerateDone(gen_start);
+        const bool is_eos =
+            token_streamer(query_idx_start + query_idx,
+                           queries_mutable_pos[query_idx], tp.token, tp.prob);
+        all_queries_eos &= is_eos;
+        gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+      }
+      if (all_queries_eos) break;
+    }  // foreach token to generate
+    timing_info.NotifyGenerateDone(gen_start);
+  }
 }
 
 template
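
Note on the hunk above: each iteration of the generation loop runs one
Transformer() decode step per query, computes logits by multiplying the
last-layer activations with the embedding matrix (MatMul), applies the
model's final logit soft-cap if the config requests one, then samples and
streams one token per query until every query emits EOS or
max_generated_tokens is reached. The sketch below is a minimal scalar
illustration of the soft-cap and sampling steps, assuming the Gemma-2-style
formula logits <- cap * tanh(logits / cap); SoftCapLogits and GreedyToken
are hypothetical stand-ins for gemma.cpp's vectorized MaybeLogitsSoftCap
and runtime-chosen sample_token, not the actual implementations.

    // Minimal scalar sketch; hypothetical helpers, not the gemma.cpp API.
    #include <stddef.h>

    #include <algorithm>
    #include <cmath>

    // Final logit soft-capping: logits <- cap * tanh(logits / cap).
    // Assumes cap == 0 means "disabled" (the "Maybe" in
    // MaybeLogitsSoftCap): models without a final cap keep raw logits.
    void SoftCapLogits(float cap, float* logits, size_t vocab_size) {
      if (cap <= 0.0f) return;
      const float inv_cap = 1.0f / cap;
      for (size_t i = 0; i < vocab_size; ++i) {
        logits[i] = cap * std::tanh(logits[i] * inv_cap);
      }
    }

    // Greedy stand-in for sample_token: returns the argmax token id.
    // The real sampler also reports the token's probability (TokenAndProb).
    int GreedyToken(const float* logits, size_t vocab_size) {
      return static_cast<int>(
          std::max_element(logits, logits + vocab_size) - logits);
    }

Capping squashes each logit into (-cap, cap), so no single token can
dominate the softmax regardless of activation magnitude; because the cap
comes from the model config (final_cap in the hunk), the same loop serves
both capped and uncapped models.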