mirror of https://github.com/google/gemma.cpp.git
parent f6d02b2870
commit bacba351d4
@@ -97,6 +97,7 @@ struct ConfigGemma7B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;

   // SSM config.
   static constexpr int kConv1dWidth = 0;
@@ -128,6 +129,7 @@ struct ConfigGemma2B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;

   // SSM config.
   static constexpr int kConv1dWidth = 0;
@@ -187,6 +189,7 @@ struct ConfigGriffin2B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;

   // SSM config.
   static constexpr int kConv1dWidth = 4;
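Note: all three built-in configs declare the new kPostNormScale flag but leave it off; it exists so that a model variant trained with post-layer norms can opt in purely at compile time. A hypothetical config sketch of what opting in would look like (no model in this commit enables the flag, and the name ConfigPostNormExample is invented for illustration):

    // Hypothetical: a config that opts into the new post-norm scales.
    // Everything else mirrors the Gemma 2B block above.
    struct ConfigPostNormExample {
      static constexpr int kQKVDim = 256;
      static constexpr int kTopK = gcpp::kTopK;
      static constexpr bool kAbsolutePE = false;
      static constexpr bool kPostNormScale = true;  // enables post_*_norm_scale

      // SSM config.
      static constexpr int kConv1dWidth = 0;
    };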
@@ -87,6 +87,7 @@ struct Layer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
+  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -121,6 +122,8 @@ struct Layer {
   ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
   ArrayT<float, kModelDim> pre_attention_norm_scale;
   ArrayT<float, kModelDim> pre_ffw_norm_scale;
+  ArrayT<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
+  ArrayT<float, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;

   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
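Note: sizing the new members with `kPostNormScale ? kModelDim : 0` means configs with the flag off pay essentially nothing for them, while the member name stays valid in every code path. A self-contained sketch of the pattern, using std::array as a stand-in for ArrayT (an assumption about its behavior):

    #include <array>
    #include <cstddef>

    template <bool kPostNormScale, size_t kModelDim>
    struct PostNormMember {
      // Zero elements when the flag is off; the member still exists, so
      // code that references it compiles either way.
      std::array<float, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
    };

    static_assert(sizeof(PostNormMember<true, 2048>) == 2048 * sizeof(float),
                  "enabled: full scale vector");
    static_assert(sizeof(PostNormMember<false, 2048>) <
                      sizeof(PostNormMember<true, 2048>),
                  "disabled: near-empty");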
@@ -269,6 +272,10 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
     SCALE_WEIGHTS(linear_w);
     READ_WEIGHTS(pre_attention_norm_scale);
     READ_WEIGHTS(pre_ffw_norm_scale);
+    if (TConfig::kPostNormScale) {
+      READ_WEIGHTS(post_attention_norm_scale);
+      READ_WEIGHTS(post_ffw_norm_scale);
+    }
     if (TConfig::kFFBiases) {
       READ_WEIGHTS(ffw_gating_biases);
       READ_WEIGHTS(ffw_output_biases);
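Note: the weight file stores tensors back to back in declaration order, so LoadWeights must guard the new reads with the same TConfig::kPostNormScale flag that sizes the struct; otherwise every later tensor's offset would shift. A minimal sketch of that invariant, assuming a plain float stream rather than the actual READ_WEIGHTS/SCALE_WEIGHTS machinery (ReadTensor and ReadPostNorms are invented names):

    #include <cstddef>
    #include <fstream>

    // Stand-in for READ_WEIGHTS: copy `len` floats from the stream.
    bool ReadTensor(std::ifstream& in, float* dst, size_t len) {
      return static_cast<bool>(
          in.read(reinterpret_cast<char*>(dst), len * sizeof(float)));
    }

    // The exporter only wrote the post-norm tensors if its config enabled
    // them, so the reader tests the identical flag to stay in sync.
    template <class TConfig, class TWeights>
    bool ReadPostNorms(std::ifstream& in, TWeights& w, size_t model_dim) {
      if (TConfig::kPostNormScale) {
        if (!ReadTensor(in, w.post_attention_norm_scale.data(), model_dim) ||
            !ReadTensor(in, w.post_ffw_norm_scale.data(), model_dim)) {
          return false;
        }
      }
      return true;
    }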
@@ -311,6 +318,7 @@ struct CompressedLayer {
   static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
   static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
   static constexpr bool kFFBiases = TLayer::kFFBiases;
+  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
   static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
   static constexpr size_t kGriffinDim = TLayer::kGriffinDim;

@@ -346,6 +354,9 @@ struct CompressedLayer {
   // We don't yet have an RMSNorm that accepts all WeightT.
   ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
   ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
+      post_attention_norm_scale;
+  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;

   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
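Note: per the in-diff comment, the norm scales stay bfloat16 even in CompressedLayer because there is no RMSNorm kernel for every WeightT yet; the bf16 value is simply widened at the point of use. A scalar sketch of that widening, assuming Highway's F32FromBF16 helper and the (1 + w) scale convention (the real kernels are vectorized and may differ; ApplyNormScale is an invented name):

    #include <cstddef>
    #include "hwy/base.h"  // hwy::bfloat16_t, hwy::F32FromBF16

    // Multiply a float activation vector by a bf16-stored norm scale.
    void ApplyNormScale(const hwy::bfloat16_t* w, float* x, size_t dim) {
      for (size_t i = 0; i < dim; ++i) {
        x[i] *= 1.0f + hwy::F32FromBF16(w[i]);  // widen bf16 -> f32 on use
      }
    }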
@@ -949,6 +960,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,

   pool.Run(0, num_tokens, [&](const uint64_t token_idx,
                               size_t /*thread*/) HWY_ATTR {
+    if (TConfig::kPostNormScale) {
+      RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
+                     activations.att_post2.data() + token_idx * kModelDim,
+                     kModelDim);
+    }
     AddFrom(activations.att_post2.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
     RMSNorm(activations.x.data() + token_idx * kModelDim,
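Note: the new post-attention step normalizes each token's att_post2 row in place before it is added into the residual stream x. A scalar sketch of what RMSNormInplace presumably computes; the epsilon, the (1 + w) convention, and the SIMD layout of the real implementation are assumptions:

    #include <cmath>
    #include <cstddef>

    // x <- RMSNorm(x) * (1 + w), in place, over one row of `dim` floats.
    void RMSNormInplaceSketch(const float* w, float* x, size_t dim) {
      float sum_sq = 0.0f;
      for (size_t i = 0; i < dim; ++i) sum_sq += x[i] * x[i];
      const float inv_rms = 1.0f / std::sqrt(sum_sq / dim + 1e-6f);
      for (size_t i = 0; i < dim; ++i) x[i] *= inv_rms * (1.0f + w[i]);
    }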
@@ -958,6 +974,11 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
   });
   FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
   for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    if (TConfig::kPostNormScale) {
+      RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
+                     activations.ffw_out.data() + token_idx * kModelDim,
+                     kModelDim);
+    }
     AddFrom(activations.ffw_out.data() + token_idx * kModelDim,
             activations.x.data() + token_idx * kModelDim, kModelDim);
   }
@@ -1005,10 +1026,18 @@ void Transformer(int token, size_t pos, const WeightArrayT& weights,
       GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights,
                           kv_cache, pool);
     }
+    if (TConfig::kPostNormScale) {
+      RMSNormInplace(layer_weights->post_attention_norm_scale.data(),
+                     activations.att_post2.data(), kModelDim);
+    }
     AddFrom(activations.att_post2.data(), activations.x.data(), kModelDim);
     RMSNorm(activations.x.data(), layer_weights->pre_ffw_norm_scale.data(),
             activations.bf_pre_ffw_rms_out.data(), kModelDim);
     FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool);
+    if (TConfig::kPostNormScale) {
+      RMSNormInplace(layer_weights->post_ffw_norm_scale.data(),
+                     activations.ffw_out.data(), kModelDim);
+    }
     AddFrom(activations.ffw_out.data(), activations.x.data(), kModelDim);
     if (layers_output != nullptr) {
       std::string block_name = "blocks." + std::to_string(layer);
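Note: the single-token Transformer path applies the same two guards as Prefill, which makes the per-layer dataflow explicit. Writing PN for the optional post-norm (identity when kPostNormScale is false), one layer computes, in the notation of the code above:

    h = x + PN_att( Attention( RMSNorm_pre_att(x) ) )
    y = h + PN_ffw( FFW( RMSNorm_pre_ffw(h) ) )

i.e. the post-norm rescales each sublayer's output before, not after, the residual add.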
@@ -1336,6 +1365,10 @@ void ForEachTensor(const Weights<TConfig>* weights,
     CALL_FUNC("gr_a", griffin.a);
   }
   CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
+  if (TConfig::kPostNormScale) {
+    CALL_FUNC("post_att_ns", post_attention_norm_scale);
+    CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
+  }

   if (TConfig::kFFBiases) {
     CALL_FUNC("ffw_gat_b", ffw_gating_biases);
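Note: ForEachTensor gates the new entries behind the same flag, so tools and checkpoints only see "post_att_ns"/"post_ff_ns" when the config defines them; older models keep enumerating the same tensor list as before. A self-contained sketch of the gating idea (the functor signature and ForEachNormTensor are assumptions; only the short tensor names come from the diff above):

    #include <cstdio>

    // Enumerate norm tensors exactly as the config dictates, mirroring the
    // CALL_FUNC guard above.
    template <bool kPostNormScale, class Func>
    void ForEachNormTensor(Func func, float* pre_att, float* post_att,
                           float* post_ffw) {
      func("pre_att_ns", pre_att);
      if (kPostNormScale) {
        func("post_att_ns", post_att);
        func("post_ff_ns", post_ffw);
      }
    }

    // Usage: count how many norm tensors a given config exposes.
    int main() {
      int n = 0;
      float dummy = 0.0f;
      auto count = [&n](const char* name, float*) { ++n; (void)name; };
      ForEachNormTensor<true>(count, &dummy, &dummy, &dummy);
      std::printf("%d norm tensors\n", n);  // prints 3
    }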