mirror of https://github.com/google/gemma.cpp.git
parent
960ff4b4ec
commit
f519ab6693
|
|
@ -354,7 +354,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
static_assert(!TConfig::kAbsolutePE);
|
||||
static_assert(!TConfig::kPostNormScale);
|
||||
static_assert(TConfig::kPostNorm == PostNormType::None);
|
||||
static_assert(TConfig::kKVHeads == 1);
|
||||
|
||||
HWY_DASSERT(prompt.context_size > 0);
|
||||
|
|
|
|||
|
|
@ -377,7 +377,7 @@ TEST(BackPropTest, InputEmbeddingVJP) {
|
|||
}
|
||||
}
|
||||
|
||||
struct TestConfig : ConfigCapNoSSM {
|
||||
struct TestConfig : ConfigBaseGemmaV2 {
|
||||
static constexpr int kSeqLen = 18;
|
||||
static constexpr int kVocabSize = 12;
|
||||
static constexpr int kModelDim = 32;
|
||||
|
|
@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
|
|||
FixedLayerConfig<2>(LayerAttentionType::kGemma);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
|
|
|
|||
|
|
@ -182,7 +182,7 @@ void TestRMSNormVJP() {
|
|||
}
|
||||
}
|
||||
|
||||
struct TestConfig : public ConfigCapNoSSM {
|
||||
struct TestConfig : public ConfigBaseGemmaV2 {
|
||||
static constexpr int kSeqLen = 24;
|
||||
static constexpr int kVocabSize = 16;
|
||||
static constexpr int kModelDim = 32;
|
||||
|
|
@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
|
|||
FixedLayerConfig<2>(LayerAttentionType::kGemma);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
|
|
|
|||
|
|
@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
|||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
static_assert(!TConfig::kAbsolutePE);
|
||||
static_assert(!TConfig::kPostNormScale);
|
||||
static_assert(TConfig::kPostNorm == PostNormType::None);
|
||||
static_assert(TConfig::kKVHeads == 1);
|
||||
|
||||
HWY_DASSERT(context_size > 0);
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ struct LoadRawWeightsT {
|
|||
SCALE_WEIGHTS(linear_w);
|
||||
READ_WEIGHTS(pre_attention_norm_scale);
|
||||
READ_WEIGHTS(pre_ffw_norm_scale);
|
||||
if (TConfig::kPostNormScale) {
|
||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||
READ_WEIGHTS(post_attention_norm_scale);
|
||||
READ_WEIGHTS(post_ffw_norm_scale);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@
|
|||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -48,7 +49,7 @@ struct Layer {
|
|||
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
static constexpr bool kFFBiases = TConfig::kFFBiases;
|
||||
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
|
||||
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
|
||||
static constexpr size_t kAOBiasDim =
|
||||
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
|
||||
static constexpr size_t kGriffinDim =
|
||||
|
|
@ -80,8 +81,10 @@ struct Layer {
|
|||
std::array<T, kModelDim * kFFHiddenDim> linear_w;
|
||||
std::array<T, kModelDim> pre_attention_norm_scale;
|
||||
std::array<T, kModelDim> pre_ffw_norm_scale;
|
||||
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
|
||||
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
|
||||
std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
|
||||
post_attention_norm_scale;
|
||||
std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
|
||||
post_ffw_norm_scale;
|
||||
|
||||
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||
|
|
|
|||
|
|
@ -244,6 +244,13 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
|
|||
Sqrt(static_cast<float>(model_dim))));
|
||||
}
|
||||
|
||||
template <class TConfig>
|
||||
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
|
||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
// QueryScaleType::Sqrt
|
||||
return 1.0f / Sqrt(static_cast<float>(kQKVDim));
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
|
|
|
|||
|
|
@ -52,6 +52,32 @@ enum class LayerAttentionType {
|
|||
kGriffinRecurrentBlock,
|
||||
};
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
enum class PostNormType {
|
||||
None,
|
||||
Scale,
|
||||
};
|
||||
|
||||
// Post qk projection operation type.
|
||||
enum class PostQKType {
|
||||
Rope,
|
||||
};
|
||||
|
||||
// FFW activation function.
|
||||
enum class ActivationType {
|
||||
Gelu,
|
||||
};
|
||||
|
||||
// Attention query scale.
|
||||
enum class QueryScaleType {
|
||||
Sqrt,
|
||||
};
|
||||
|
||||
// Residual connection type.
|
||||
enum class ResidualType {
|
||||
Add,
|
||||
};
|
||||
|
||||
template <size_t kNum>
|
||||
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
|
||||
LayerAttentionType type) {
|
||||
|
|
@ -107,21 +133,27 @@ struct ConfigNoSSM {
|
|||
static constexpr bool kUseLocalAttention = false;
|
||||
static constexpr bool kInterleaveQKV = true;
|
||||
static constexpr int kNumTensorScales = 0;
|
||||
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
struct ConfigNoCapNoSSM : ConfigNoSSM {
|
||||
struct ConfigBaseGemmaV1 : ConfigNoSSM {
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
|
||||
// For Gemma2 with SoftCap
|
||||
struct ConfigCapNoSSM : ConfigNoSSM {
|
||||
struct ConfigBaseGemmaV2 : ConfigNoSSM {
|
||||
static constexpr float kAttCap = 50.0f;
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::Scale;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma27B : public ConfigCapNoSSM {
|
||||
struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
|
|
@ -143,11 +175,10 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
|
|||
static constexpr int kQKVDim = 128; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = true;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma9B : public ConfigCapNoSSM {
|
||||
struct ConfigGemma9B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
|
|
@ -169,11 +200,10 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
|
|||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = true;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma7B : public ConfigNoCapNoSSM {
|
||||
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
|
|
@ -191,11 +221,10 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
|
|||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma2B : public ConfigNoCapNoSSM {
|
||||
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
|
|
@ -213,7 +242,6 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
|
|||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
|
|
@ -235,7 +263,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
|
|||
static constexpr int kQKVDim = 16; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
// This is required for optimize_test to pass.
|
||||
|
|
@ -294,7 +322,7 @@ struct ConfigGriffin2B {
|
|||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr bool kPostNormScale = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
// No SoftCap.
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
|
|
@ -308,6 +336,10 @@ struct ConfigGriffin2B {
|
|||
static constexpr bool kUseLocalAttention = true;
|
||||
static constexpr bool kInterleaveQKV = false;
|
||||
static constexpr int kNumTensorScales = 140;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -195,6 +195,13 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig, typename T>
|
||||
HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) {
|
||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
// PostQKType::Rope
|
||||
Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void Attention(
|
||||
size_t batch_and_query_start, size_t num_tokens, size_t num_queries,
|
||||
|
|
@ -216,8 +223,7 @@ HWY_NOINLINE void Attention(
|
|||
constexpr size_t kHeads = TConfig::kHeads;
|
||||
constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||
constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
GEMMA_CONSTEXPR_SQRT const float kQueryScale =
|
||||
1.0f / Sqrt(static_cast<float>(kQKVDim));
|
||||
GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale<TConfig>();
|
||||
// Multi-Head Attention a.k.a. "use_qkv_einsum".
|
||||
constexpr bool kIsMHA = TActivations::kIsMHA;
|
||||
static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved
|
||||
|
|
@ -278,8 +284,7 @@ HWY_NOINLINE void Attention(
|
|||
// Skip past the Q part of `q`, and copy KV to `kv`.
|
||||
memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
|
||||
}
|
||||
// Apply rope to K.
|
||||
Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
PostQK<TConfig>(kv, pos, layer);
|
||||
});
|
||||
|
||||
static_assert((kHeads % kKVHeads) == 0,
|
||||
|
|
@ -299,13 +304,16 @@ HWY_NOINLINE void Attention(
|
|||
|
||||
// Apply rope and scaling to Q.
|
||||
const size_t pos = batch_start + batch_idx;
|
||||
Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
|
||||
PostQK<TConfig>(q, pos, layer);
|
||||
MulByConst(kQueryScale, q, kQKVDim);
|
||||
|
||||
// Compute Q.K scores, yielding "logits" (or scores) in head_att.
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations.att.data() + head * kSeqLen
|
||||
+ batch_and_query_idx * kHeads * kSeqLen;
|
||||
|
||||
|
||||
// Compute Q dot K scores
|
||||
const size_t start_pos =
|
||||
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
|
||||
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
|
||||
|
|
@ -372,6 +380,18 @@ HWY_NOINLINE void Attention(
|
|||
}
|
||||
}
|
||||
|
||||
template <class TConfig, typename T>
|
||||
HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2,
|
||||
size_t count) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<T>;
|
||||
using VF = hn::Vec<DF>;
|
||||
// ActivationType::Gelu
|
||||
hn::Transform1(DF(), c1, count, c2, [](DF df, VF v, VF mul) HWY_ATTR {
|
||||
return hn::Mul(mul, Gelu(df, v));
|
||||
});
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize>
|
||||
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
||||
size_t num_tokens,
|
||||
|
|
@ -400,14 +420,9 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b2, activations.C2.data(),
|
||||
pool);
|
||||
|
||||
// Gelu and multiply by gate.
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
|
||||
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
|
||||
return hn::Mul(mul, Gelu(df, v));
|
||||
});
|
||||
// Activation (Gelu) and multiply by gate.
|
||||
Activation<TConfig>(activations.C1.data(), activations.C2.data(),
|
||||
kFFHiddenDim * num_tokens);
|
||||
|
||||
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
|
||||
layer_weights->linear_w.data(),
|
||||
|
|
@ -431,12 +446,7 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
|
|||
layer_weights->gating_einsum_w, 0, vec,
|
||||
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
|
||||
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
|
||||
[](DF df, VF v, VF mul)
|
||||
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
|
||||
Activation<TConfig>(out, out_mul, kFFHiddenDim);
|
||||
|
||||
MatVecT</*kAdd=*/true, kModelDim, kFFHiddenDim>(
|
||||
layer_weights->linear_w, 0,
|
||||
|
|
@ -467,6 +477,16 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
|
|||
};
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize, typename T>
|
||||
HWY_NOINLINE void ResidualConnection(
|
||||
size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
||||
const CompressedLayer<TConfig>* layer_weights, bool is_attention) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
// ResidualType::Add
|
||||
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries, other, x,
|
||||
kModelDim);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
HWY_NOINLINE void TransformerLayer(
|
||||
size_t num_tokens, size_t num_queries, size_t pos, size_t layer,
|
||||
|
|
@ -497,29 +517,31 @@ HWY_NOINLINE void TransformerLayer(
|
|||
layer_weights, kv_caches, pool);
|
||||
}
|
||||
}
|
||||
if (TConfig::kPostNormScale) {
|
||||
|
||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries,
|
||||
layer_weights->post_attention_norm_scale.data(),
|
||||
activations.att_post2.data(), kModelDim);
|
||||
}
|
||||
AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries,
|
||||
activations.att_post2.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
|
||||
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.att_post2.data(),
|
||||
activations.x.data(), layer_weights, /*is_attention*/ true);
|
||||
RMSNormBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.x.data(),
|
||||
layer_weights->pre_ffw_norm_scale.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), kModelDim);
|
||||
FFW<TConfig, kBatchSize * kQueryBatchSize>(
|
||||
activations, num_tokens_and_queries, layer_weights, pool);
|
||||
if (TConfig::kPostNormScale) {
|
||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||
RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
|
||||
activations.ffw_out.data(), kModelDim);
|
||||
}
|
||||
AddFromBatched<kBatchSize * kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.ffw_out.data(),
|
||||
activations.x.data(), kModelDim);
|
||||
ResidualConnection<TConfig, kBatchSize, kQueryBatchSize>(
|
||||
num_tokens_and_queries, activations.ffw_out.data(), activations.x.data(),
|
||||
layer_weights, /*is_attention*/ false);
|
||||
}
|
||||
|
||||
template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ struct CompressedLayer {
|
|||
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
static constexpr bool kFFBiases = TConfig::kFFBiases;
|
||||
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
|
||||
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
|
||||
static constexpr size_t kAOBiasDim =
|
||||
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
|
||||
static constexpr size_t kGriffinDim =
|
||||
|
|
@ -86,9 +86,10 @@ struct CompressedLayer {
|
|||
// We don't yet have an RMSNorm that accepts all Weight.
|
||||
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
|
||||
ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
|
||||
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
|
||||
ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
|
||||
post_attention_norm_scale;
|
||||
ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
|
||||
ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
|
||||
post_ffw_norm_scale;
|
||||
|
||||
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||
|
|
@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
|
|||
GEMMA_CALL_FUNC("gr_a", griffin.a);
|
||||
}
|
||||
GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
|
||||
if (TConfig::kPostNormScale) {
|
||||
if (TConfig::kPostNorm == PostNormType::Scale) {
|
||||
GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
|
||||
GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
|
||||
}
|
||||
|
|
@ -331,7 +332,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
|
|||
GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
|
||||
GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
|
||||
GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
|
||||
if (TConfig::kPostNormScale) { \
|
||||
if (TConfig::kPostNorm == PostNormType::Scale) { \
|
||||
GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
|
||||
GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
|
||||
} \
|
||||
|
|
|
|||
Loading…
Reference in New Issue