diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 4ca5f20..45edbc7 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -17,6 +17,7 @@
 
 // Compiles this file for multiple architectures via "foreach_target.h", to
 // which we pass the filename via macro 'argument'.
+#include
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma/gemma.cc"  // NOLINT
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
@@ -592,16 +593,15 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   pool.Run(
       0, num_tokens,
       [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
         const int token = tokens[token_idx];
-        HWY_ASSERT(token >= 0);
-        HWY_ASSERT(token < TConfig::kVocabSize);
+        HWY_DASSERT(token >= 0);
+        HWY_DASSERT(token < TConfig::kVocabSize);
         Decompress(weights.embedder_input_embedding, token * kModelDim,
                    activations.x.data() + token_idx * kModelDim, kModelDim);
         MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
                    kModelDim);
         if constexpr (TConfig::kAbsolutePE) {
           AddAbsolutePositionalEmbeddings(
-              activations.x.data() + token_idx * kModelDim, TConfig::kModelDim,
-              pos);
+              activations.x.data() + token_idx * kModelDim, kModelDim, pos);
         };
       });
@@ -646,72 +646,92 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                  activations.x.data(), kModelDim);
 }
 
-// n = 1 specialization
-template <class TConfig, typename WeightArrayT, typename LayersOutputT>
-HWY_NOINLINE void Transformer(int token, size_t pos,
+// Compute the transformer for a batch of input tokens. During generation,
+// we usually have num_tokens == 1 (and also kBatchSize == 1).
+template <size_t kBatchSize, class TConfig, typename WeightArrayT,
+          typename LayersOutputT>
+HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
                               const WeightArrayT& weights,
-                              Activations<TConfig, 1>& activations,
+                              Activations<TConfig, kBatchSize>& activations,
                               KVCache& kv_cache, hwy::ThreadPool& pool,
                               LayersOutputT* layers_output) {
+  HWY_ASSERT(num_tokens <= kBatchSize);
   if (layers_output != nullptr) {
-    float token_f = token;
-    (*layers_output)(pos, "Tokens", &token_f, 1);
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      float token_f = tokens[token_idx];
+      (*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
+    }
   }
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  Decompress(weights.embedder_input_embedding, token * kModelDim,
-             activations.x.data(), kModelDim);
-  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling();
-  MulByConst(kEmbScaling, activations.x.data(), kModelDim);
-  if constexpr (TConfig::kAbsolutePE) {
-    AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim,
-                                    pos);
-  };
+  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling();
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    const int token = tokens[token_idx];
+    HWY_DASSERT(token >= 0);
+    HWY_DASSERT(token < TConfig::kVocabSize);
+    Decompress(weights.embedder_input_embedding, token * kModelDim,
+               activations.x.data() + token_idx * kModelDim, kModelDim);
+    MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
+               kModelDim);
+    if constexpr (TConfig::kAbsolutePE) {
+      AddAbsolutePositionalEmbeddings(
+          activations.x.data() + token_idx * kModelDim, kModelDim,
+          pos + token_idx);
+    };
+  }
+
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
     auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
     size_t layer_of_type =
         NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    RMSNormBatched<1>(1, activations.x.data(),
-                      layer_weights->pre_attention_norm_scale.data(),
-                      activations.pre_att_rms_out.data(), kModelDim);
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_attention_norm_scale.data(),
+                               activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
-      Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
-                   pool);
+      Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                            layer_weights, kv_cache, pool);
     } else {
-      GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
-                          kv_cache, pool);
+      GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
+                                   layer_weights, kv_cache, pool);
     }
     if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<1>(1,
-                               layer_weights->post_attention_norm_scale.data(),
-                               activations.att_post2.data(), kModelDim);
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_attention_norm_scale.data(),
+          activations.att_post2.data(), kModelDim);
     }
-    AddFromBatched<1>(1, activations.att_post2.data(), activations.x.data(),
-                      kModelDim);
-    RMSNormBatched<1>(1, activations.x.data(),
-                      layer_weights->pre_ffw_norm_scale.data(),
-                      activations.bf_pre_ffw_rms_out.data(), kModelDim);
-    FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
+    AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
+                               activations.x.data(), kModelDim);
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_ffw_norm_scale.data(),
+                               activations.bf_pre_ffw_rms_out.data(),
+                               kModelDim);
+    FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
     if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched<1>(1, layer_weights->post_ffw_norm_scale.data(),
-                               activations.ffw_out.data(), kModelDim);
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_ffw_norm_scale.data(),
+          activations.ffw_out.data(), kModelDim);
     }
-    AddFromBatched<1>(1, activations.ffw_out.data(), activations.x.data(),
-                      kModelDim);
+    AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
+                               activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);
-      (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
+      for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+        (*layers_output)(pos + token_idx, block_name,
+                         activations.x.data() + token_idx * kModelDim,
+                         kModelDim);
+      }
     }
   }
   // Placeholder for internal test4, do not remove
-  RMSNormInplaceBatched<1>(1, weights.final_norm_scale.data(),
-                           activations.x.data(), kModelDim);
+  RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
+                                    activations.x.data(), kModelDim);
   if (layers_output != nullptr) {
-    (*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
+    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+      (*layers_output)(pos + token_idx, "final_norm",
+                       activations.x.data() + token_idx * kModelDim, kModelDim);
+    }
   }
 }
@@ -781,6 +801,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
   }
   HWY_ASSERT(prompt_size > 0);
 
+  // If no sample_func is provided, we use top-k sampling.
   const SampleFunc sample_token =
       runtime_config.sample_func
           ? runtime_config.sample_func
@@ -827,36 +848,35 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
         static_cast<double>(pos_offset) / (prefill_end - prefill_start);
   }
 
+  // Start generation.
   const double gen_start = hwy::platform::Now();
-  HWY_DASSERT(pos_offset == prompt_size - 1);
-
   size_t pos_gen_start = pos_offset;
   int token = prompt.at(pos_offset);
-  runtime_config.stream_token(token, 0);
+  // The loop below is not yet prepared for batch size > 1.
+  HWY_ASSERT(kDecodeBatchSize == 1);
+  if (!runtime_config.stream_token(token, 0.0f)) return;
   for (size_t generate_pos = 0;
        pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
-    const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool,
-                layers_output);
-    float* final_activation = activations.x.data();
+    Transformer(&token, kDecodeBatchSize, pos, weights,
+                activations, kv_cache, pool, layers_output);
+    float token_logit = 0.0f;
     // The condition below is always true if we are doing Prefill above.
     // We keep it here for clarity so that the code is correct even if Prefill
     // is disabled.
+    const bool is_generating_phase = pos_offset >= prompt_size - 1;
     if (is_generating_phase) {
       PROFILER_ZONE("Gen.Embedding");
-      // Generation phase
+      // Compute logits from last layer activations.
       MatVec(
-          weights.embedder_input_embedding, 0, final_activation,
+          weights.embedder_input_embedding, 0, activations.x.data(),
           activations.even_odd.data(), activations.logits.data(), pool);
       LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = sample_token(activations.logits.data(), kVocabSize);
-      if (!runtime_config.stream_token(token, activations.logits[token])) {
-        token = runtime_config.eos_id;
-      }
+      token_logit = activations.logits[token];
       if (generate_pos == 0) {
         timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
       }
@@ -864,20 +884,19 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
       // We would take this branch if we were not doing Prefill but would
       // process the tokens of the prompt one at a time.
       token = prompt.at(pos_offset + 1);
-      if (!runtime_config.stream_token(token, 0)) {
-        token = runtime_config.eos_id;
-      }
+    }
+    if (!runtime_config.stream_token(token, token_logit)) {
+      token = runtime_config.eos_id;
     }
     if (token == runtime_config.eos_id) {
-      if (runtime_config.verbosity >= 2) {
-        const double gen_end = hwy::platform::Now();
-        timing_info.gen_tok_sec =
-            static_cast<double>(pos_offset - pos_gen_start) /
-            (gen_end - gen_start);
-      }
       break;
     }
   }
+  if (runtime_config.verbosity >= 2) {
+    const double gen_end = hwy::platform::Now();
+    timing_info.gen_tok_sec =
+        static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
+  }
 }
 
 }  // namespace HWY_NAMESPACE
@@ -898,7 +917,7 @@ struct AllocatePrefill {
 template <class TConfig>
 struct AllocateDecode {
   ByteStorageT operator()() const {
-    return AllocateSizeof<Activations<TConfig, 1>>();
+    return AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
   }
 };
 }  // namespace
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 4e896ed..403740f 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -31,6 +31,7 @@ namespace gcpp {
 
 constexpr size_t kPrefillBatchSize = 16;
+constexpr size_t kDecodeBatchSize = 1;
 constexpr bool kSystemPrompt = false;
 
 struct KVCache {
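
The standalone sketch below (not part of the patch) illustrates the batched activation layout the new Transformer() relies on: activations.x is treated as a [kBatchSize x kModelDim] buffer, and token token_idx reads and writes the row at offset token_idx * kModelDim, which is how the per-token loops above index it. EmbedToken() and the toy embedding table are hypothetical stand-ins for the Decompress()/MulByConst() calls in gemma.cc.

// Minimal sketch of the per-token row layout used by the batched Transformer().
#include <cassert>
#include <cstddef>
#include <vector>

constexpr size_t kModelDim = 4;   // toy model dimension
constexpr size_t kBatchSize = 2;  // kDecodeBatchSize is 1 in the patch

// Copies one embedding row and applies the embedding scale, analogous to the
// Decompress() + MulByConst() pair inside the token loop.
void EmbedToken(const std::vector<float>& embedding_table, int token,
                float emb_scaling, float* x_row) {
  for (size_t d = 0; d < kModelDim; ++d) {
    x_row[d] = embedding_table[token * kModelDim + d] * emb_scaling;
  }
}

int main() {
  // 3 tokens x kModelDim toy embedding table.
  const std::vector<float> embedding_table = {0.1f, 0.2f, 0.3f, 0.4f,  //
                                              1.0f, 1.1f, 1.2f, 1.3f,  //
                                              2.0f, 2.1f, 2.2f, 2.3f};
  const int tokens[kBatchSize] = {2, 0};
  const size_t num_tokens = 2;
  assert(num_tokens <= kBatchSize);  // same invariant as in Transformer()

  std::vector<float> x(kBatchSize * kModelDim, 0.0f);  // activations.x
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    // Each token owns the row starting at token_idx * kModelDim.
    EmbedToken(embedding_table, tokens[token_idx], /*emb_scaling=*/1.0f,
               x.data() + token_idx * kModelDim);
  }
  return 0;
}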