diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index db601c9..fb874a4 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -602,70 +602,86 @@ static void AddFromBatched(size_t num_tokens, const float* other, float* x,
 
 // Placeholder for internal test3, do not remove
 
+// Embeds the token at `token_idx` into activations.x: decompresses the
+// embedding row, applies the embedding scaling and, if the config uses
+// absolute positional embeddings, adds them in.
+template <class TConfig, size_t kBatchSize, typename WeightArrayT>
+HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
+                             const WeightArrayT& weights,
+                             Activations<TConfig, kBatchSize>& activations) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
+      EmbeddingScaling<TConfig>();
+  HWY_DASSERT(token >= 0);
+  HWY_DASSERT(token < TConfig::kVocabSize);
+  Decompress(weights.embedder_input_embedding, token * kModelDim,
+             activations.x.data() + token_idx * kModelDim, kModelDim);
+  MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
+             kModelDim);
+  if constexpr (TConfig::kAbsolutePE) {
+    AddAbsolutePositionalEmbeddings(
+        activations.x.data() + token_idx * kModelDim, kModelDim,
+        pos + token_idx);
+  };
+}
+
+// Runs one transformer layer over a batch of tokens: pre-attention RMSNorm,
+// attention (or Griffin recurrence), the FFW block, optional post-norms, and
+// residual additions back into activations.x.
+template <class TConfig, size_t kBatchSize, typename LayerWeightArrayT>
+HWY_NOINLINE void TransformerLayer(
+    size_t num_tokens, size_t pos, size_t layer,
+    const LayerWeightArrayT* layer_weights,
+    Activations<TConfig, kBatchSize>& activations, KVCache& kv_cache,
+    hwy::ThreadPool& pool) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  auto type = TConfig::kLayerConfig[layer];
+  size_t layer_of_type =
+      NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
+  RMSNormBatched(num_tokens, activations.x.data(),
+                 layer_weights->pre_attention_norm_scale.data(),
+                 activations.pre_att_rms_out.data(), kModelDim);
+  if (type == LayerAttentionType::kGemma) {
+    Attention(pos, num_tokens, layer_of_type, activations,
+              layer_weights, kv_cache, pool);
+  } else {
+    GriffinRecurrent(pos, num_tokens, layer_of_type, activations,
+                     layer_weights, kv_cache, pool);
+  }
+  if (TConfig::kPostNormScale) {
+    RMSNormInplaceBatched(
+        num_tokens, layer_weights->post_attention_norm_scale.data(),
+        activations.att_post2.data(), kModelDim);
+  }
+  AddFromBatched(num_tokens, activations.att_post2.data(),
+                 activations.x.data(), kModelDim);
+  RMSNormBatched(num_tokens, activations.x.data(),
+                 layer_weights->pre_ffw_norm_scale.data(),
+                 activations.bf_pre_ffw_rms_out.data(), kModelDim);
+  FFW(activations, num_tokens, layer_weights, pool);
+  if (TConfig::kPostNormScale) {
+    RMSNormInplaceBatched(num_tokens,
+                          layer_weights->post_ffw_norm_scale.data(),
+                          activations.ffw_out.data(), kModelDim);
+  }
+  AddFromBatched(num_tokens, activations.ffw_out.data(),
+                 activations.x.data(), kModelDim);
+}
+
 template <class TConfig, size_t kBatchSize, typename WeightArrayT>
 HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           const WeightArrayT& weights,
                           Activations<TConfig, kBatchSize>& activations,
                           KVCache& kv_cache, hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
-      EmbeddingScaling<TConfig>();
   pool.Run(
       0, num_tokens,
       [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        const int token = tokens[token_idx];
-        HWY_DASSERT(token >= 0);
-        HWY_DASSERT(token < TConfig::kVocabSize);
-        Decompress(weights.embedder_input_embedding, token * kModelDim,
-                   activations.x.data() + token_idx * kModelDim, kModelDim);
-        MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
-                   kModelDim);
-        if constexpr (TConfig::kAbsolutePE) {
-          AddAbsolutePositionalEmbeddings(
-              activations.x.data() + token_idx * kModelDim, kModelDim, pos);
-        };
+        EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
       });
 
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
-    size_t layer_of_type =
-        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-
-    RMSNormBatched(num_tokens, activations.x.data(),
-                   layer_weights->pre_attention_norm_scale.data(),
-                   activations.pre_att_rms_out.data(), kModelDim);
-    if (type == LayerAttentionType::kGemma) {
-      Attention(pos, num_tokens, layer_of_type, activations,
-                layer_weights, kv_cache, pool);
-    } else {
-      GriffinRecurrent(pos, num_tokens, layer_of_type, activations,
-                       layer_weights, kv_cache, pool);
-    }
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched(
-          num_tokens, layer_weights->post_attention_norm_scale.data(),
-          activations.att_post2.data(), kModelDim);
-    }
-    AddFromBatched(num_tokens, activations.att_post2.data(),
-                   activations.x.data(), kModelDim);
-    RMSNormBatched(num_tokens, activations.x.data(),
-                   layer_weights->pre_ffw_norm_scale.data(),
-                   activations.bf_pre_ffw_rms_out.data(),
-                   kModelDim);
-    FFW(activations, num_tokens, layer_weights, pool);
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched(
-          num_tokens, layer_weights->post_ffw_norm_scale.data(),
-          activations.ffw_out.data(), kModelDim);
-    }
-    AddFromBatched(num_tokens, activations.ffw_out.data(),
-                   activations.x.data(), kModelDim);
-  }  // foreach layer
+    TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
+                     kv_cache, pool);
+  }
 
   RMSNormInplaceBatched(num_tokens, weights.final_norm_scale.data(),
-                        activations.x.data(), kModelDim);
+                        activations.x.data(), TConfig::kModelDim);
 }
 
 // Compute the transformer for a batch of input tokens. During generation,
@@ -684,57 +700,14 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
     }
   }
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
-      EmbeddingScaling<TConfig>();
   for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-    const int token = tokens[token_idx];
-    HWY_DASSERT(token >= 0);
-    HWY_DASSERT(token < TConfig::kVocabSize);
-    Decompress(weights.embedder_input_embedding, token * kModelDim,
-               activations.x.data() + token_idx * kModelDim, kModelDim);
-    MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
-               kModelDim);
-    if constexpr (TConfig::kAbsolutePE) {
-      AddAbsolutePositionalEmbeddings(
-          activations.x.data() + token_idx * kModelDim, kModelDim,
-          pos + token_idx);
-    };
+    EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
   }
   for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-    auto type = TConfig::kLayerConfig[layer];
     const auto* layer_weights = weights.GetLayer(layer);
-    size_t layer_of_type =
-        NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    RMSNormBatched(num_tokens, activations.x.data(),
-                   layer_weights->pre_attention_norm_scale.data(),
-                   activations.pre_att_rms_out.data(), kModelDim);
-    if (type == LayerAttentionType::kGemma) {
-      Attention(pos, num_tokens, layer_of_type, activations,
-                layer_weights, kv_cache, pool);
-    } else {
-      GriffinRecurrent(pos, num_tokens, layer_of_type, activations,
-                       layer_weights, kv_cache, pool);
-    }
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched(
-          num_tokens, layer_weights->post_attention_norm_scale.data(),
-          activations.att_post2.data(), kModelDim);
-    }
-    AddFromBatched(num_tokens, activations.att_post2.data(),
-                   activations.x.data(), kModelDim);
-    RMSNormBatched(num_tokens, activations.x.data(),
-                   layer_weights->pre_ffw_norm_scale.data(),
-                   activations.bf_pre_ffw_rms_out.data(),
-                   kModelDim);
-    FFW(activations, num_tokens, layer_weights, pool);
-    if (TConfig::kPostNormScale) {
-      RMSNormInplaceBatched(
-          num_tokens, layer_weights->post_ffw_norm_scale.data(),
-          activations.ffw_out.data(), kModelDim);
-    }
-    AddFromBatched(num_tokens, activations.ffw_out.data(),
-                   activations.x.data(), kModelDim);
+    TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
+                     kv_cache, pool);
     if (layers_output) {
       std::string block_name = "blocks." + std::to_string(layer);
       for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {