From f519ab6693aac92f74a7082c3758f7acd9c55eaa Mon Sep 17 00:00:00 2001
From: Kan Wu
Date: Wed, 10 Jul 2024 21:30:23 -0700
Subject: [PATCH] Refactor configurables.

PiperOrigin-RevId: 651259154
---
 backprop/backward-inl.h          |  2 +-
 backprop/backward_scalar_test.cc |  4 ++--
 backprop/backward_test.cc        |  4 ++--
 backprop/forward-inl.h           |  2 +-
 compression/compress_weights.cc  |  2 +-
 compression/weights_raw.h        |  9 ++--
 gemma/common.h                   |  7 +++
 gemma/configs.h                  | 58 +++++++++++++++++++----------
 gemma/gemma-inl.h                | 73 ++++++++++++++++++++-------------
 gemma/weights.h                  | 11 +++---
 10 files changed, 117 insertions(+), 55 deletions(-)

diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index fe693ca..492e4e7 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -354,7 +354,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling<TConfig>();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);
 
   HWY_DASSERT(prompt.context_size > 0);
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index b261359..121d6a8 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -377,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-struct TestConfig : ConfigCapNoSSM {
+struct TestConfig : ConfigBaseGemmaV2 {
   static constexpr int kSeqLen = 18;
   static constexpr int kVocabSize = 12;
   static constexpr int kModelDim = 32;
@@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;
 
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
index 94b164f..1e7dbd1 100644
--- a/backprop/backward_test.cc
+++ b/backprop/backward_test.cc
@@ -182,7 +182,7 @@ void TestRMSNormVJP() {
   }
 }
 
-struct TestConfig : public ConfigCapNoSSM {
+struct TestConfig : public ConfigBaseGemmaV2 {
   static constexpr int kSeqLen = 24;
   static constexpr int kVocabSize = 16;
   static constexpr int kModelDim = 32;
@@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;
 
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index c24116f..636c23c 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling<TConfig>();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);
 
   HWY_DASSERT(context_size > 0);
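Aside: the four backprop hunks above all make the same mechanical change, a
boolean static_assert turning into an enum comparison. A minimal,
self-contained sketch of the pattern (TestConfig and the function body here
are invented for illustration, not taken from the patch):

    enum class PostNormType { None, Scale };

    struct TestConfig {
      static constexpr PostNormType kPostNorm = PostNormType::None;
    };

    template <class TConfig>
    float CrossEntropyLossForwardPass() {
      // Was: static_assert(!TConfig::kPostNormScale);
      // The enum form reads the same today but stays meaningful if a third
      // PostNormType variant is ever added.
      static_assert(TConfig::kPostNorm == PostNormType::None,
                    "this scalar path does not implement post-norm scaling");
      return 0.0f;  // loss computation elided
    }

    int main() {
      return CrossEntropyLossForwardPass<TestConfig>() == 0.0f ? 0 : 1;
    }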
diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc
index d57a2a1..2ca7400 100644
--- a/compression/compress_weights.cc
+++ b/compression/compress_weights.cc
@@ -159,7 +159,7 @@ struct LoadRawWeightsT {
     SCALE_WEIGHTS(linear_w);
     READ_WEIGHTS(pre_attention_norm_scale);
     READ_WEIGHTS(pre_ffw_norm_scale);
-    if (TConfig::kPostNormScale) {
+    if (TConfig::kPostNorm == PostNormType::Scale) {
       READ_WEIGHTS(post_attention_norm_scale);
       READ_WEIGHTS(post_ffw_norm_scale);
     }
diff --git a/compression/weights_raw.h b/compression/weights_raw.h
index 774c6f2..0819feb 100644
--- a/compression/weights_raw.h
+++ b/compression/weights_raw.h
@@ -27,6 +27,7 @@
 #include <array>
 
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -48,7 +49,7 @@ struct Layer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -80,8 +81,10 @@ struct Layer {
   std::array<float, kModelDim * kFFHiddenDim> linear_w;
   std::array<float, kModelDim> pre_attention_norm_scale;
   std::array<float, kModelDim> pre_ffw_norm_scale;
-  std::array<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
-  std::array<float, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_attention_norm_scale;
+  std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_ffw_norm_scale;
   std::array<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   std::array<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 
diff --git a/gemma/common.h b/gemma/common.h
index 151c6b0..70c950c 100644
--- a/gemma/common.h
+++ b/gemma/common.h
@@ -244,6 +244,13 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
       Sqrt(static_cast<float>(model_dim))));
 }
 
+template <class TConfig>
+GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
+  constexpr size_t kQKVDim = TConfig::kQKVDim;
+  // QueryScaleType::Sqrt
+  return 1.0f / Sqrt(static_cast<float>(kQKVDim));
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
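The new ChooseQueryScale() in common.h gives the 1/sqrt(kQKVDim) query
scaling a single home instead of inlining it at each attention site. A
standalone sketch of the same idea (DemoConfig is invented, and std::sqrt
stands in for the constexpr-capable Sqrt wrapper the real header uses):

    #include <cmath>
    #include <cstddef>
    #include <cstdio>

    enum class QueryScaleType { Sqrt };

    struct DemoConfig {
      static constexpr std::size_t kQKVDim = 256;
      static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
    };

    template <class TConfig>
    float ChooseQueryScale() {
      // Only QueryScaleType::Sqrt exists so far; a second variant would
      // branch on TConfig::kQueryScale right here.
      return 1.0f / std::sqrt(static_cast<float>(TConfig::kQKVDim));
    }

    int main() {
      std::printf("%g\n", ChooseQueryScale<DemoConfig>());  // prints 0.0625
      return 0;
    }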
diff --git a/gemma/configs.h b/gemma/configs.h
index b7e2a44..f1327d3 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -52,6 +52,32 @@ enum class LayerAttentionType {
   kGriffinRecurrentBlock,
 };
 
+// Post attention and ffw normalization type.
+enum class PostNormType {
+  None,
+  Scale,
+};
+
+// Post qk projection operation type.
+enum class PostQKType {
+  Rope,
+};
+
+// FFW activation function.
+enum class ActivationType {
+  Gelu,
+};
+
+// Attention query scale.
+enum class QueryScaleType {
+  Sqrt,
+};
+
+// Residual connection type.
+enum class ResidualType {
+  Add,
+};
+
 template <size_t kNum>
 constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
     LayerAttentionType type) {
@@ -107,21 +133,27 @@ struct ConfigNoSSM {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
+
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };
 
-struct ConfigNoCapNoSSM : ConfigNoSSM {
+struct ConfigBaseGemmaV1 : ConfigNoSSM {
   static constexpr float kAttCap = 0.0f;
   static constexpr float kFinalCap = 0.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 };
 
-// For Gemma2 with SoftCap
-struct ConfigCapNoSSM : ConfigNoSSM {
+struct ConfigBaseGemmaV2 : ConfigNoSSM {
   static constexpr float kAttCap = 50.0f;
   static constexpr float kFinalCap = 30.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
 };
 
 template <typename TWeight>
-struct ConfigGemma27B : public ConfigCapNoSSM {
+struct ConfigGemma27B : public ConfigBaseGemmaV2 {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = 8192;
@@ -143,11 +175,10 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 128;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
 };
 
 template <typename TWeight>
-struct ConfigGemma9B : public ConfigCapNoSSM {
+struct ConfigGemma9B : public ConfigBaseGemmaV2 {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = 8192;
@@ -169,11 +200,10 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
 };
 
 template <typename TWeight>
-struct ConfigGemma7B : public ConfigNoCapNoSSM {
+struct ConfigGemma7B : public ConfigBaseGemmaV1 {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -191,11 +221,10 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
 };
 
 template <typename TWeight>
-struct ConfigGemma2B : public ConfigNoCapNoSSM {
+struct ConfigGemma2B : public ConfigBaseGemmaV1 {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -213,7 +242,6 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
 };
 
 template <typename TWeight>
@@ -235,7 +263,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kQKVDim = 16;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 
   static constexpr float kAttCap = 0.0f;
   // This is required for optimize_test to pass.
@@ -294,7 +322,7 @@ struct ConfigGriffin2B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 
   // No SoftCap.
   static constexpr float kAttCap = 0.0f;
@@ -308,6 +336,10 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };
 
 }  // namespace gcpp
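The net effect of the configs.h changes: kAttCap, kFinalCap and the new
kPostNorm move into ConfigBaseGemmaV1/ConfigBaseGemmaV2, so a per-model
struct states its Gemma generation once instead of restating the flags. A
trimmed-down, compilable sketch (the Demo* structs are illustrative
stand-ins, not the real configs):

    enum class PostNormType { None, Scale };

    struct ConfigBaseGemmaV1 {
      static constexpr float kAttCap = 0.0f;
      static constexpr float kFinalCap = 0.0f;
      static constexpr PostNormType kPostNorm = PostNormType::None;
    };

    struct ConfigBaseGemmaV2 {
      static constexpr float kAttCap = 50.0f;
      static constexpr float kFinalCap = 30.0f;
      static constexpr PostNormType kPostNorm = PostNormType::Scale;
    };

    // Gemma-1 family: no soft caps, no post-norm. Gemma-2: both.
    struct DemoGemma2B : ConfigBaseGemmaV1 {};
    struct DemoGemma27B : ConfigBaseGemmaV2 {};

    static_assert(DemoGemma2B::kPostNorm == PostNormType::None, "");
    static_assert(DemoGemma27B::kPostNorm == PostNormType::Scale, "");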
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 18b4ac5..a3812da 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -195,6 +195,13 @@ HWY_NOINLINE void GriffinRecurrent(
   }
 }
 
+template <class TConfig, typename T>
+HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
+  constexpr size_t kQKVDim = TConfig::kQKVDim;
+  // PostQKType::Rope
+  Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+}
+
 template <class TConfig, class TActivations, class KVCaches>
 HWY_NOINLINE void Attention(
     size_t batch_and_query_start, size_t num_tokens, size_t num_queries,
@@ -216,8 +223,7 @@ HWY_NOINLINE void Attention(
   constexpr size_t kHeads = TConfig::kHeads;
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
-  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
-      1.0f / Sqrt(static_cast<float>(kQKVDim));
+  GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
   // Multi-Head Attention a.k.a. "use_qkv_einsum".
   constexpr bool kIsMHA = TActivations::kIsMHA;
   static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
@@ -278,8 +284,7 @@ HWY_NOINLINE void Attention(
       // Skip past the Q part of `q`, and copy KV to `kv`.
       memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
     }
-    // Apply rope to K.
-    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    PostQK<TConfig>(kv, pos, layer);
   });
 
   static_assert((kHeads % kKVHeads) == 0,
@@ -299,13 +304,13 @@ HWY_NOINLINE void Attention(
     // Apply rope and scaling to Q.
     const size_t pos = batch_start + batch_idx;
-    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    PostQK<TConfig>(q, pos, layer);
     MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q.K scores, yielding "logits" (or scores) in head_att.
     float* HWY_RESTRICT head_att = activations.att.data() +
                                    head * kSeqLen +
                                    batch_and_query_idx * kHeads * kSeqLen;
     const size_t start_pos =
         pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
     for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
@@ -372,6 +377,18 @@ HWY_NOINLINE void Attention(
   }
 }
 
+template <class TConfig, typename T>
+HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
+                             size_t count) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using DF = hn::ScalableTag<T>;
+  using VF = hn::Vec<DF>;
+  // ActivationType::Gelu
+  hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR {
+    return hn::Mul(mul, Gelu(df, v));
+  });
+}
+
 template <class TConfig, size_t kBatchSize>
 HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
                       size_t num_tokens,
@@ -400,14 +417,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   MatMul_4x4_Batch<kModelDim, kFFHiddenDim>(num_tokens, A, b2,
                                             activations.C2.data(), pool);
 
-  // Gelu and multiply by gate.
-  namespace hn = hwy::HWY_NAMESPACE;
-  using DF = hn::ScalableTag<float>;
-  using VF = hn::Vec<DF>;
-  hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
-                 activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
-                   return hn::Mul(mul, Gelu(df, v));
-                 });
+  // Activation (Gelu) and multiply by gate.
+  Activation<TConfig>(activations.C1.data(), activations.C2.data(),
+                      kFFHiddenDim * num_tokens);
 
   MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
                                             layer_weights->linear_w.data(),
@@ -431,12 +443,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
         layer_weights->gating_einsum_w, 0, vec,
         layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
 
-    namespace hn = hwy::HWY_NAMESPACE;
-    using DF = hn::ScalableTag<float>;
-    using VF = hn::Vec<DF>;
-    hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
-                   [](DF df, VF v, VF mul)
-                       HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
+    Activation<TConfig>(out, out_mul, kFFHiddenDim);
 
     MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
         layer_weights->linear_w, 0,
@@ -467,6 +474,16 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
   };
 }
 
+template <size_t kBatchAndQuerySize, class TConfig, typename T>
+HWY_NOINLINE void ResidualConnection(
+    size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
+    const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
+  constexpr size_t kModelDim = TConfig::kModelDim;
+  // ResidualType::Add
+  AddFromBatched<kBatchAndQuerySize>(num_tokens_and_queries, other, x,
+                                     kModelDim);
+}
+
 template <class TConfig, size_t kBatchAndQuerySize>
 HWY_NOINLINE void TransformerLayer(
     size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
@@ -497,29 +514,31 @@ HWY_NOINLINE void TransformerLayer(
                                           layer_weights, kv_caches, pool);
     }
   }
-  if (TConfig::kPostNormScale) {
+
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchAndQuerySize>(
         num_tokens_and_queries,
        layer_weights->post_attention_norm_scale.data(),
        activations.att_post2.data(), kModelDim);
   }
-  AddFromBatched<kBatchAndQuerySize>(num_tokens_and_queries,
-                                     activations.att_post2.data(),
-                                     activations.x.data(), kModelDim);
+
+  ResidualConnection<kBatchAndQuerySize, TConfig>(
+      num_tokens_and_queries, activations.att_post2.data(),
+      activations.x.data(), layer_weights, /*is_attention*/ true);
 
   RMSNormBatched<kBatchAndQuerySize>(
       num_tokens_and_queries, activations.x.data(),
       layer_weights->pre_ffw_norm_scale.data(),
       activations.bf_pre_ffw_rms_out.data(), kModelDim);
 
   FFW<TConfig, kBatchAndQuerySize>(
       activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchAndQuerySize>(
         num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
         activations.ffw_out.data(), kModelDim);
   }
-  AddFromBatched<kBatchAndQuerySize>(
-      num_tokens_and_queries, activations.ffw_out.data(),
-      activations.x.data(), kModelDim);
+  ResidualConnection<kBatchAndQuerySize, TConfig>(
+      num_tokens_and_queries, activations.ffw_out.data(), activations.x.data(),
+      layer_weights, /*is_attention*/ false);
 }
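PostQK, Activation and ResidualConnection each have exactly one
implementation so far, with a comment naming the enum value they realize;
the point of the seam is that a future variant only has to switch on the
TConfig enum inside the helper, not at every call site. A scalar sketch of
the two simplest seams (the real helpers are Highway-vectorized; the Gelu
below is the usual tanh approximation, which may differ in detail from the
library's vectorized Gelu):

    #include <cmath>
    #include <cstddef>

    enum class ActivationType { Gelu };
    enum class ResidualType { Add };

    // Applies Gelu to c1 and multiplies by the gate values in c2, like the
    // vectorized Activation<TConfig> above.
    template <class TConfig>
    void Activation(float* c1, const float* c2, std::size_t count) {
      // ActivationType::Gelu
      for (std::size_t i = 0; i < count; ++i) {
        const float x = c1[i];
        const float gelu =
            0.5f * x *
            (1.0f + std::tanh(0.7978845608f * x * (1.0f + 0.044715f * x * x)));
        c1[i] = gelu * c2[i];
      }
    }

    // The residual seam: today the only choice is ResidualType::Add.
    template <class TConfig>
    void ResidualConnection(const float* other, float* x, std::size_t count) {
      // ResidualType::Add
      for (std::size_t i = 0; i < count; ++i) x[i] += other[i];
    }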
   ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
   ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
-  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_ffw_norm_scale;
 
   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
@@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     GEMMA_CALL_FUNC("gr_a", griffin.a);
   }
   GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
     GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
   }
@@ -331,7 +332,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
   GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w);                \
   GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w);                         \
   GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale);       \
-  if (TConfig::kPostNormScale) {                                            \
+  if (TConfig::kPostNorm == PostNormType::Scale) {                          \
    GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale);    \
    GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale);           \
   }                                                                         \
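One closing observation on the weight structs: sizing the post-norm tensors
with "kPostNorm == PostNormType::Scale ? kModelDim : 0" means configs
without post-norm declare zero-length arrays, and ForEachTensor skips the
same tensors behind the same test, so essentially no storage is read or
serialized for them. A compilable sketch of the sizing trick (DemoLayer is
an invented stand-in for Layer/CompressedLayer):

    #include <array>
    #include <cstddef>

    enum class PostNormType { None, Scale };

    template <PostNormType kPostNorm, std::size_t kModelDim>
    struct DemoLayer {
      std::array<float, kModelDim> pre_attention_norm_scale;
      // Zero elements unless the config enables post-norm scaling.
      std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
          post_attention_norm_scale;
    };

    static_assert(sizeof(DemoLayer<PostNormType::Scale, 16>) >
                      sizeof(DemoLayer<PostNormType::None, 16>),
                  "the None variant must be smaller");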