diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index bd0f062..ff551c2 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -546,6 +546,37 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   }
 }
 
+// The below "batched" versions are just simple loops for now.
+template <size_t kBatchSize, typename WeightT, typename OutT>
+static void RMSNormBatched(size_t num_tokens, const float* activations,
+                           const WeightT* weights, OutT* out,
+                           const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNorm(activations + token_idx * model_dim, weights,
+            out + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize, typename WeightT, typename InOutT>
+static void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
+                                  InOutT* inout, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize>
+static void AddFromBatched(size_t num_tokens, const float* other, float* x,
+                           const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
+            model_dim);
+  }
+}
+
 // Placeholder for internal test3, do not remove
 
 template <size_t kBatchSize, typename TConfig>
@@ -580,12 +611,9 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
 
     size_t layer_of_type =
         NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      RMSNorm(activations.x.data() + token_idx * kModelDim,
-              layer_weights->pre_attention_norm_scale.data(),
-              activations.pre_att_rms_out.data() + token_idx * kModelDim,
-              kModelDim);
-    }
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_attention_norm_scale.data(),
+                               activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
       Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                             layer_weights, kv_cache, pool);
@@ -593,38 +621,29 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
       GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
                                    layer_weights, kv_cache, pool);
     }
-
-    pool.Run(0, num_tokens, [&](const uint64_t token_idx,
-                                size_t /*thread*/) HWY_ATTR {
-      if (TConfig::kPostNormScale) {
-        RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
-                       activations.att_post2.data() + token_idx * kModelDim,
-                       kModelDim);
-      }
-      AddFrom(activations.att_post2.data() + token_idx * kModelDim,
-              activations.x.data() + token_idx * kModelDim, kModelDim);
-      RMSNorm(activations.x.data() + token_idx * kModelDim,
-              layer_weights->pre_ffw_norm_scale.data(),
-              activations.bf_pre_ffw_rms_out.data() + token_idx * kModelDim,
-              kModelDim);
-    });
-    FFW(activations, num_tokens, layer_weights, pool);
-    for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
-      if (TConfig::kPostNormScale) {
-        RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
-                       activations.ffw_out.data() + token_idx * kModelDim,
-                       kModelDim);
-      }
-      AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
-              activations.x.data() + token_idx * kModelDim, kModelDim);
+    if (TConfig::kPostNormScale) {
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_attention_norm_scale.data(),
+          activations.att_post2.data(), kModelDim);
     }
+    AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
+                               activations.x.data(), kModelDim);
+    RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
+                               layer_weights->pre_ffw_norm_scale.data(),
+                               activations.bf_pre_ffw_rms_out.data(),
+                               kModelDim);
+    FFW(activations, num_tokens, layer_weights, pool);
+    if (TConfig::kPostNormScale) {
+      RMSNormInplaceBatched<kBatchSize>(
+          num_tokens, layer_weights->post_ffw_norm_scale.data(),
+          activations.ffw_out.data(), kModelDim);
+    }
+    AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
+                               activations.x.data(), kModelDim);
   }  // foreach layer
-  pool.Run(
-      0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
-        RMSNormInplace(weights.final_norm_scale.data(),
-                       activations.x.data() + token_idx * kModelDim, kModelDim);
-      });
+  RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
+                                    activations.x.data(), kModelDim);
 }
 
 // n = 1 specialization
 
@@ -654,9 +673,9 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
     const auto* layer_weights = weights.GetLayer(layer);
     size_t layer_of_type =
         NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
-    RMSNorm(activations.x.data(),
-            layer_weights->pre_attention_norm_scale.data(),
-            activations.pre_att_rms_out.data(), kModelDim);
+    RMSNormBatched<1>(1, activations.x.data(),
+                      layer_weights->pre_attention_norm_scale.data(),
+                      activations.pre_att_rms_out.data(), kModelDim);
     if (type == LayerAttentionType::kGemma) {
       Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache,
                    pool);
@@ -665,18 +684,22 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
                           kv_cache, pool);
     }
     if (TConfig::kPostNormScale) {
-      RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
-                     activations.att_post2.data(), kModelDim);
+      RMSNormInplaceBatched<1>(1,
+                               layer_weights->post_attention_norm_scale.data(),
+                               activations.att_post2.data(), kModelDim);
     }
-    AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
-    RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
-            activations.bf_pre_ffw_rms_out.data(), kModelDim);
+    AddFromBatched<1>(1, activations.att_post2.data(), activations.x.data(),
+                      kModelDim);
+    RMSNormBatched<1>(1, activations.x.data(),
+                      layer_weights->pre_ffw_norm_scale.data(),
+                      activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
     if (TConfig::kPostNormScale) {
-      RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
-                     activations.ffw_out.data(), kModelDim);
+      RMSNormInplaceBatched<1>(1, layer_weights->post_ffw_norm_scale.data(),
+                               activations.ffw_out.data(), kModelDim);
    }
-    AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
+    AddFromBatched<1>(1, activations.ffw_out.data(), activations.x.data(),
+                      kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);
       (*layers_output)(pos, block_name, activations.x.data(), kModelDim);
@@ -685,8 +708,8 @@ HWY_NOINLINE void Transformer(int token, size_t pos,
 
   // Placeholder for internal test4, do not remove
 
-  RMSNormInplace(weights.final_norm_scale.data(), activations.x.data(),
-                 kModelDim);
+  RMSNormInplaceBatched<1>(1, weights.final_norm_scale.data(),
+                           activations.x.data(), kModelDim);
   if (layers_output != nullptr) {
     (*layers_output)(pos, "final_norm", activations.x.data(), kModelDim);
   }
diff --git a/gemma/ops.h b/gemma/ops.h
index 702efc5..eea838d 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -942,18 +942,20 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
   return hn::ReduceSum(d, hn::Add(sum0, sum1));
 }
 
+// float, float -> float; simple loop.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
     float* HWY_RESTRICT out, size_t size) {
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   float ss = SquaredL2(x, size);
-  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
+  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
   for (size_t j = 0; j < size; j++) {
     // Note 1.0f centering here
     out[j] = (1.0f + weight[j]) * (ss * x[j]);
   }
 }
 
+// x=f, w=bf16 -> out=f
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
     const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
     float* HWY_RESTRICT out, size_t size) {
@@ -984,11 +986,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   }
 }
 
+// float -> float; simple loop.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
     const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   float ss = SquaredL2(inout, size);
-  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
+  ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
   for (size_t j = 0; j < size; j++) {
     // Note 1.0f centering here
     inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@@ -1005,10 +1008,10 @@
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(inout, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
@@ -1034,10 +1037,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(x, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
@@ -1062,10 +1065,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
   using VF = hn::Vec<decltype(df32)>;
   const size_t N32 = hn::Lanes(df32);
 
-  constexpr float eps = 1e-6f;
+  constexpr float kEps = 1e-6f;
   const float ss = SquaredL2(x, size);
   const VF vss =
-      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
+      hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
 
   HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
   for (size_t i = 0; i < size; i += 2 * N32) {
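
Reviewer sketch, not part of the patch: the new helpers are deliberately plain loops, as the comment in gemma.cc notes. If the per-token work ever needs to go back onto the thread pool, a parallel variant could mirror the pool.Run() pattern this change removes from Prefill. The name RMSNormBatchedParallel below is hypothetical and assumes the same hwy::ThreadPool that Prefill already receives.

// Hypothetical sketch only -- not in this diff. Same per-token RMSNorm as
// RMSNormBatched, but tokens are distributed across the pool via the
// pool.Run() idiom used by the old Prefill body.
template <size_t kBatchSize, typename WeightT, typename OutT>
static void RMSNormBatchedParallel(size_t num_tokens, const float* activations,
                                   const WeightT* weights, OutT* out,
                                   const size_t model_dim,
                                   hwy::ThreadPool& pool) {
  HWY_DASSERT(num_tokens <= kBatchSize);
  pool.Run(0, num_tokens,
           [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
             RMSNorm(activations + token_idx * model_dim, weights,
                     out + token_idx * model_dim, model_dim);
           });
}

Whether the parallel dispatch pays off depends on model_dim and the prefill batch size; the simple loops keep this change minimal.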