diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 2dd1d63..f36730f 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -353,9 +353,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
-    MatVecT(
-        layer_weights->attn_vec_einsum_w, 0, att_out,
-        layer_weights->attention_output_biases.data_scale1(),
+    constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
+    const float* bias =
+        kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr;
+    MatVecT(
+        layer_weights->attn_vec_einsum_w, 0, att_out, bias,
         activations.even_odd.All(), layer_out, pool);
     // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
@@ -425,8 +427,14 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
-  const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
-  const auto bias2 = bias1 + kFFHiddenDim;
+  const float* bias1 = nullptr;
+  const float* bias2 = nullptr;
+  const float* output_bias = nullptr;
+  if constexpr (kAddBias) {
+    bias1 = layer_weights->ffw_gating_biases.data_scale1();
+    bias2 = bias1 + kFFHiddenDim;
+    output_bias = layer_weights->ffw_output_biases.data_scale1();
+  }
 
   // Will go through GELU.
   MatMul_4x4_Batch_Add(num_interleaved, A, B1, scale,
@@ -442,7 +450,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   MatMul_4x4_Batch_Add(
       num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
-      layer_weights->ffw_output_biases.data_scale1(), pool);
+      output_bias, pool);
 }
 
 // TODO: pass Activations.x instead of Activations.
@@ -477,9 +485,10 @@ HWY_NOINLINE void ResidualConnection(
 }
 
 template
-void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
   if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+    RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
+                          TConfig::kModelDim);
   }
 }
 
@@ -501,8 +510,7 @@ HWY_NOINLINE void TransformerLayer(
   Attention(type, pos, num_tokens, num_queries, layer_of_type, activations,
             layer_weights, kv_caches, pool);
 
-  PostNorm(num_interleaved,
-           layer_weights->post_attention_norm_scale.data_scale1(),
+  PostNorm(num_interleaved, layer_weights->post_attention_norm_scale,
            activations.att_post2.All());
 
   ResidualConnection(num_interleaved, activations.att_post2.All(),
@@ -515,8 +523,7 @@ HWY_NOINLINE void TransformerLayer(
 
   FFW(activations, num_interleaved, layer_weights, pool);
 
-  PostNorm(num_interleaved,
-           layer_weights->post_ffw_norm_scale.data_scale1(),
+  PostNorm(num_interleaved, layer_weights->post_ffw_norm_scale,
            activations.ffw_out.All());
 
   ResidualConnection(num_interleaved, activations.ffw_out.All(),
diff --git a/gemma/weights.h b/gemma/weights.h
index cfc981c..6e0782d 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -230,6 +230,11 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 //
 // This avoids repeating the list of tensors between loading and compressing,
 // while also avoiding dependency on raw_weights.h.
+//
+// This only calls Func for tensors that TConfig requests/specifies, which means
+// scale() is uninitialized for the other tensors, so their data_scale1() must
+// not be called. (In other words, if the config doesn't specify a tensor, it
+// shouldn't be used.)
 template
 void ForEachTensor(RawWeightsPtr raw_weights,
                    CompressedWeights& c_weights, Func& func) {
@@ -269,33 +274,19 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     }
     GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
 
-    // For conditionally-included tensors, the else branch must ensure their
-    // scale is initialized, because wrapper functions call data_scale1 even if
-    // the tensor turns out to be unused. If unused, the arrays are zero-length
-    // and data() returns a non-null but unusable pointer.
-
     if (TConfig::kPostNorm == PostNormType::Scale) {
       GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
       GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
-    } else {
-      c_layer->post_attention_norm_scale.set_scale(1.0f);
-      c_layer->post_ffw_norm_scale.set_scale(1.0f);
     }
 
     if (TConfig::kFFBiases) {
      GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
      GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
-    } else {
-      c_layer->ffw_gating_biases.set_scale(1.0f);
-      c_layer->ffw_output_biases.set_scale(1.0f);
     }
 
-    if (type == LayerAttentionType::kGemma) {
-      if (TConfig::kSoftmaxAttnOutputBiases) {
-        GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
-      } else {
-        c_layer->attention_output_biases.set_scale(1.0f);
-      }
+    if (TConfig::kSoftmaxAttnOutputBiases &&
+        type == LayerAttentionType::kGemma) {
+      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
     }
   }
 #undef GEMMA_CALL_FUNC
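
Note on the pattern: the diff resolves each optional bias to a pointer up front
and passes nullptr when the config omits the tensor, so data_scale1() is never
called on a tensor whose scale ForEachTensor leaves uninitialized. A minimal
standalone sketch of that idea follows; LayerBiases, kHasBias, and
AddBiasIfPresent are hypothetical names for illustration only, not gemma.cpp
APIs.

    // Flag-gated bias pointer: only dereference data_scale1() when the
    // compile-time flag says the tensor exists; otherwise pass nullptr on.
    #include <cstddef>
    #include <vector>

    struct LayerBiases {
      std::vector<float> values;  // empty when the config omits the tensor
      const float* data_scale1() const { return values.data(); }
    };

    template <bool kHasBias>
    void AddBiasIfPresent(const LayerBiases& biases, float* out, size_t n) {
      const float* bias = nullptr;
      if constexpr (kHasBias) {
        bias = biases.data_scale1();  // only evaluated when kHasBias is true
      }
      for (size_t i = 0; i < n; ++i) {
        if (bias != nullptr) out[i] += bias[i];
      }
    }

With kHasBias == false the bias pointer stays nullptr and data_scale1() is
never called, mirroring the if constexpr (kAddBias) block added to FFW above.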