Add configurables for norm/rope/activation/scale/residual connection.

PiperOrigin-RevId: 648971168
This commit is contained in:
Kan Wu 2024-07-03 00:29:01 -07:00 committed by Copybara-Service
parent 7e4b20455e
commit cca75c5c60
9 changed files with 85 additions and 34 deletions

View File

@ -355,7 +355,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>(); const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE); static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale); static_assert(TConfig::kPostNorm == PostNormType::None);
static_assert(TConfig::kKVHeads == 1); static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(prompt.context_size > 0); HWY_DASSERT(prompt.context_size > 0);

View File

@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
FixedLayerConfig<2>(LayerAttentionType::kGemma); FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr int kKVHeads = 1; static constexpr int kKVHeads = 1;
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;

View File

@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
FixedLayerConfig<2>(LayerAttentionType::kGemma); FixedLayerConfig<2>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr int kKVHeads = 1; static constexpr int kKVHeads = 1;
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;

View File

@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
static constexpr size_t kLayers = TConfig::kLayers; static constexpr size_t kLayers = TConfig::kLayers;
const float kEmbScaling = EmbeddingScaling<TConfig>(); const float kEmbScaling = EmbeddingScaling<TConfig>();
static_assert(!TConfig::kAbsolutePE); static_assert(!TConfig::kAbsolutePE);
static_assert(!TConfig::kPostNormScale); static_assert(TConfig::kPostNorm == PostNormType::None);
static_assert(TConfig::kKVHeads == 1); static_assert(TConfig::kKVHeads == 1);
HWY_DASSERT(context_size > 0); HWY_DASSERT(context_size > 0);

View File

@ -52,6 +52,32 @@ enum class LayerAttentionType {
kGriffinRecurrentBlock, kGriffinRecurrentBlock,
}; };
// Post attention and ffw normalization type.
enum class PostNormType {
None,
Scale,
};
// Post qk projection operation type.
enum class PostQKType {
Rope,
};
// FFW activation function.
enum class ActivationType {
Gelu,
};
// Attention query scale.
enum class QueryScaleType {
Sqrt,
};
// Residual connection type.
enum class ResidualType {
Add,
};
template <size_t kNum> template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig( constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) { LayerAttentionType type) {
@ -107,6 +133,11 @@ struct ConfigNoSSM {
static constexpr bool kUseLocalAttention = false; static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true; static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0; static constexpr int kNumTensorScales = 0;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add;
}; };
struct ConfigNoCapNoSSM : ConfigNoSSM { struct ConfigNoCapNoSSM : ConfigNoSSM {
@ -143,7 +174,7 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
static constexpr int kQKVDim = 128; // query size == key size == value size static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true; static constexpr PostNormType kPostNorm = PostNormType::Scale;
}; };
template <typename TWeight> template <typename TWeight>
@ -169,7 +200,7 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true; static constexpr PostNormType kPostNorm = PostNormType::Scale;
}; };
template <typename TWeight> template <typename TWeight>
@ -191,7 +222,7 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
}; };
template <typename TWeight> template <typename TWeight>
@ -213,7 +244,7 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
}; };
template <typename TWeight> template <typename TWeight>
@ -235,7 +266,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kQKVDim = 16; // query size == key size == value size static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr float kAttCap = 0.0f; static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass. // This is required for optimize_test to pass.
@ -294,7 +325,7 @@ struct ConfigGriffin2B {
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = false; static constexpr PostNormType kPostNorm = PostNormType::None;
// No SoftCap. // No SoftCap.
static constexpr float kAttCap = 0.0f; static constexpr float kAttCap = 0.0f;
@ -308,6 +339,9 @@ struct ConfigGriffin2B {
static constexpr bool kUseLocalAttention = true; static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false; static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140; static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add;
}; };
} // namespace gcpp } // namespace gcpp

View File

@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
constexpr size_t kHeads = TConfig::kHeads; constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kKVHeads = TConfig::kKVHeads; constexpr size_t kKVHeads = TConfig::kKVHeads;
constexpr size_t kSeqLen = TConfig::kSeqLen; constexpr size_t kSeqLen = TConfig::kSeqLen;
GEMMA_CONSTEXPR_SQRT const float kQueryScale = GEMMA_CONSTEXPR_SQRT float kQueryScale =
1.0f / Sqrt(static_cast<float>(kQKVDim)); 1.0f / Sqrt(static_cast<float>(kQKVDim));
constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention constexpr bool kIsMHA = TActivations::kIsMHA; // Multi-Head Attention
const size_t batch_start = batch_and_query_start / num_queries; const size_t batch_start = batch_and_query_start / num_queries;
const size_t num_tokens_and_queries = num_tokens * num_queries; const size_t num_tokens_and_queries = num_tokens * num_queries;
@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
// Skip past the Q part of `q`, and copy KV to `kv`. // Skip past the Q part of `q`, and copy KV to `kv`.
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
} }
if (TConfig::kPostQK == PostQKType::Rope) {
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
}); });
static_assert((kHeads % kKVHeads) == 0, static_assert((kHeads % kKVHeads) == 0,
@ -373,7 +376,10 @@ HWY_NOINLINE void Attention(
activations.att.data() + head * kSeqLen activations.att.data() + head * kSeqLen
+ batch_and_query_idx * kHeads * kSeqLen; + batch_and_query_idx * kHeads * kSeqLen;
if (TConfig::kPostQK == PostQKType::Rope) {
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
}
MulByConst(kQueryScale, q, kQKVDim); MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores // Compute Q dot K scores
@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>; using VF = hn::Vec<DF>;
if (TConfig::kActivation == ActivationType::Gelu) {
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens, hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR { activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v)); return hn::Mul(mul, Gelu(df, v));
}); });
}
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(), MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
layer_weights->linear_w.data(), layer_weights->linear_w.data(),
@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
layer_weights, kv_caches, pool); layer_weights, kv_caches, pool);
} }
} }
if (TConfig::kPostNormScale) {
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>( RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data(), layer_weights->post_attention_norm_scale.data(),
activations.att_post2.data(), kModelDim); activations.att_post2.data(), kModelDim);
} }
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries, if (TConfig::kResidual == ResidualType::Add) {
activations.att_post2.data(), AddFromBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.att_post2.data(),
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
}
RMSNormBatched<kBatchSize * kQueryBatchSize>( RMSNormBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.x.data(), num_tokens_and_queries, activations.x.data(),
layer_weights->pre_ffw_norm_scale.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim); activations.bf_pre_ffw_rms_out.data(), kModelDim);
FFW<TConfig, kBatchSize * kQueryBatchSize>( FFW<TConfig, kBatchSize * kQueryBatchSize>(
activations, num_tokens_and_queries, layer_weights, pool); activations, num_tokens_and_queries, layer_weights, pool);
if (TConfig::kPostNormScale) { if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>( RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(), num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
activations.ffw_out.data(), kModelDim); activations.ffw_out.data(), kModelDim);
} }
if (TConfig::kResidual == ResidualType::Add) {
AddFromBatched<kBatchSize * kQueryBatchSize>( AddFromBatched<kBatchSize * kQueryBatchSize>(
num_tokens_and_queries, activations.ffw_out.data(), num_tokens_and_queries, activations.ffw_out.data(),
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
}
} }
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize> template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>

View File

@ -50,7 +50,7 @@ struct CompressedLayer {
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases; static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale; static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim = static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim = static constexpr size_t kGriffinDim =
@ -86,9 +86,10 @@ struct CompressedLayer {
// We don't yet have an RMSNorm that accepts all Weight. // We don't yet have an RMSNorm that accepts all Weight.
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale; ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale; ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_attention_norm_scale; post_attention_norm_scale;
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale; ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases; ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases; ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
GEMMA_CALL_FUNC("gr_a", griffin.a); GEMMA_CALL_FUNC("gr_a", griffin.a);
} }
GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale); GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
if (TConfig::kPostNormScale) { if (TConfig::kPostNorm == PostNormType::Scale) {
GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale); GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale); GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
} }
@ -331,7 +332,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \ GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \ GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
if (TConfig::kPostNormScale) { \ if (TConfig::kPostNorm == PostNormType::Scale) { \
GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \ GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
} \ } \

View File

@ -25,6 +25,7 @@
#include <random> #include <random>
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -46,7 +47,7 @@ struct Layer {
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases; static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale; static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim = static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim = static constexpr size_t kGriffinDim =
@ -78,8 +79,10 @@ struct Layer {
std::array<T, kModelDim * kFFHiddenDim> linear_w; std::array<T, kModelDim * kFFHiddenDim> linear_w;
std::array<T, kModelDim> pre_attention_norm_scale; std::array<T, kModelDim> pre_attention_norm_scale;
std::array<T, kModelDim> pre_ffw_norm_scale; std::array<T, kModelDim> pre_ffw_norm_scale;
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale; std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale; post_attention_norm_scale;
std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
post_ffw_norm_scale;
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases; std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases; std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;

View File

@ -159,7 +159,7 @@ struct LoadRawWeightsT {
SCALE_WEIGHTS(linear_w); SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale); READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale); READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) { if (TConfig::kPostNorm == PostNormType::Scale) {
READ_WEIGHTS(post_attention_norm_scale); READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale); READ_WEIGHTS(post_ffw_norm_scale);
} }