Fix msan uninitialized scale

PiperOrigin-RevId: 653655471
This commit is contained in:
Jan Wassenberg 2024-07-18 09:41:53 -07:00 committed by Copybara-Service
parent e87e65ca45
commit 3fe79b3876
1 changed files with 13 additions and 3 deletions

View File

@ -276,11 +276,21 @@ void ForEachTensor(RawWeightsPtr raw_weights,
if (TConfig::kFFBiases) { if (TConfig::kFFBiases) {
GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases); GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases); GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
} else {
// Ensure initialized so we can call data_scale1, which happens even if
// the tensor turns out to be unused.
c_layer->ffw_gating_biases.set_scale(1.0f);
c_layer->ffw_output_biases.set_scale(1.0f);
} }
if (TConfig::kSoftmaxAttnOutputBiases && if (type == LayerAttentionType::kGemma) {
type == LayerAttentionType::kGemma) { if (TConfig::kSoftmaxAttnOutputBiases) {
GEMMA_CALL_FUNC("attn_ob", attention_output_biases); GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
} else {
// Ensure initialized so we can call data_scale1, which happens even if
// the tensor turns out to be unused.
c_layer->attention_output_biases.set_scale(1.0f);
}
} }
} }
#undef GEMMA_CALL_FUNC #undef GEMMA_CALL_FUNC