Fix msan uninitialized scale in optimize_test

PiperOrigin-RevId: 654817460
Daniel Keysers 2024-07-22 10:49:49 -07:00 committed by Copybara-Service
parent 74a6dc8f33
commit 33334ad454
2 changed files with 27 additions and 29 deletions
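In short: several bias/norm tensors are loaded only when the model config enables them. For disabled tensors the scale field is never written, so eagerly evaluating data_scale1() at a call site reads uninitialized memory, which MSan flags. Below is a minimal, self-contained sketch of the pattern this commit applies; all types and names in it are illustrative stand-ins, not the actual gemma.cpp API:

#include <cstdio>

// Illustrative stand-in for a compressed tensor whose `scale` is written
// only when the config includes the tensor.
struct TensorSketch {
  float scale;                  // left uninitialized for skipped tensors
  const float* data = nullptr;  // unused storage in that case
  const float* data_scale1() const {
    // Reads `scale` (the real version consults it too); for a skipped
    // tensor this is the uninitialized read that MSan flags.
    return scale == 1.0f ? data : nullptr;
  }
};

struct ConfigSketch {
  static constexpr bool kSoftmaxAttnOutputBiases = false;
};

int main() {
  TensorSketch attention_output_biases;  // skipped: scale never written
  constexpr bool kAdd = ConfigSketch::kSoftmaxAttnOutputBiases;
  // The fix: gate the data_scale1() call on the same config flag, so the
  // scale of a skipped tensor is never read.
  const float* bias =
      kAdd ? attention_output_biases.data_scale1() : nullptr;
  std::printf("bias is %s\n", bias ? "set" : "nullptr");
  return 0;
}

With kSoftmaxAttnOutputBiases false, the ternary never evaluates data_scale1(), so the uninitialized scale is never read; the diffs below apply the same gating to the FFW biases and the post-norm scales.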


@@ -353,9 +353,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
-    MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-        layer_weights->attn_vec_einsum_w, 0, att_out,
-        layer_weights->attention_output_biases.data_scale1(),
+    constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
+    const float* bias =
+        kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr;
+    MatVecT<kAdd, kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, 0, att_out, bias,
         activations.even_odd.All(), layer_out, pool);
     // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
@@ -425,8 +427,14 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
-  const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
-  const auto bias2 = bias1 + kFFHiddenDim;
+  const float* bias1 = nullptr;
+  const float* bias2 = nullptr;
+  const float* output_bias = nullptr;
+  if constexpr (kAddBias) {
+    bias1 = layer_weights->ffw_gating_biases.data_scale1();
+    bias2 = bias1 + kFFHiddenDim;
+    output_bias = layer_weights->ffw_output_biases.data_scale1();
+  }
   // Will go through GELU.
   MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B1, scale,
@@ -442,7 +450,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
       num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
-      layer_weights->ffw_output_biases.data_scale1(), pool);
+      output_bias, pool);
 }

 // TODO: pass Activations.x instead of Activations.
@@ -477,9 +485,10 @@ HWY_NOINLINE void ResidualConnection(
 }

 template <class TConfig, typename WeightT, typename InOutT>
-void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
   if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+    RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
+                          TConfig::kModelDim);
   }
 }
@@ -501,8 +510,7 @@ HWY_NOINLINE void TransformerLayer(
   Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
                      activations, layer_weights, kv_caches, pool);
-  PostNorm<TConfig>(num_interleaved,
-                    layer_weights->post_attention_norm_scale.data_scale1(),
+  PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
                     activations.att_post2.All());
   ResidualConnection<TConfig>(num_interleaved, activations.att_post2.All(),
@@ -515,8 +523,7 @@ HWY_NOINLINE void TransformerLayer(
   FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
-  PostNorm<TConfig>(num_interleaved,
-                    layer_weights->post_ffw_norm_scale.data_scale1(),
+  PostNorm<TConfig>(num_interleaved, layer_weights->post_ffw_norm_scale,
                     activations.ffw_out.All());
   ResidualConnection<TConfig>(num_interleaved, activations.ffw_out.All(),
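The PostNorm change above is the same idea at a call boundary: evaluating data_scale1() in the argument list reads the tensor's scale even when kPostNorm makes the whole call a no-op. Here is a self-contained sketch of why passing the tensor by reference defers that read; the stub types and the stubbed RMSNormInplaceBatched merely stand in for the real ones:

#include <cstddef>

enum class PostNormType { None, Scale };

// Stub standing in for the real compressed tensor type.
struct NormTensor {
  float scale;  // uninitialized unless the config loads this tensor
  const float* data_scale1() const { return data_; }
  const float* data_ = nullptr;
};

// Stub with the same shape as the real RMSNormInplaceBatched call.
void RMSNormInplaceBatched(size_t, const float*, float*, size_t) {}

struct Cfg {
  static constexpr PostNormType kPostNorm = PostNormType::None;
  static constexpr size_t kModelDim = 4;
};

// Mirrors the new signature: the tensor is passed by reference, and
// data_scale1() runs only inside the branch that actually uses it.
template <class TConfig>
void PostNorm(size_t num_interleaved, const NormTensor& weights,
              float* inout) {
  if (TConfig::kPostNorm == PostNormType::Scale) {
    RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
                          TConfig::kModelDim);
  }
}

int main() {
  NormTensor post_ffw_norm_scale;  // skipped tensor: scale never written
  float x[4] = {};
  // Safe under MSan: PostNorm is a no-op for this config and nothing in the
  // skipped tensor is read. The old pointer-taking signature forced callers
  // to evaluate data_scale1() here regardless.
  PostNorm<Cfg>(1, post_ffw_norm_scale, x);
  return 0;
}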


@@ -230,6 +230,11 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 //
 // This avoids repeating the list of tensors between loading and compressing,
 // while also avoiding dependency on raw_weights.h.
+//
+// This only calls Func for tensors that TConfig requests/specifies, which means
+// scale() is uninitialized for the other tensors, so their data_scale1() must
+// not be called. (In other words, if the config doesn't specify a tensor, it
+// shouldn't be used.)
 template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
 void ForEachTensor(RawWeightsPtr raw_weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
@@ -269,33 +274,19 @@ void ForEachTensor(RawWeightsPtr raw_weights,
   }
   GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
-  // For conditionally-included tensors, the else branch must ensure their
-  // scale is initialized, because wrapper functions call data_scale1 even if
-  // the tensor turns out to be unused. If unused, the arrays are zero-length
-  // and data() returns a non-null but unusable pointer.
   if (TConfig::kPostNorm == PostNormType::Scale) {
     GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
     GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
-  } else {
-    c_layer->post_attention_norm_scale.set_scale(1.0f);
-    c_layer->post_ffw_norm_scale.set_scale(1.0f);
   }
   if (TConfig::kFFBiases) {
     GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
     GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
-  } else {
-    c_layer->ffw_gating_biases.set_scale(1.0f);
-    c_layer->ffw_output_biases.set_scale(1.0f);
   }
-  if (type == LayerAttentionType::kGemma) {
-    if (TConfig::kSoftmaxAttnOutputBiases) {
-      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
-    } else {
-      c_layer->attention_output_biases.set_scale(1.0f);
-    }
+  if (TConfig::kSoftmaxAttnOutputBiases &&
+      type == LayerAttentionType::kGemma) {
+    GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
   }
 }

 #undef GEMMA_CALL_FUNC
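To illustrate the new ForEachTensor contract described in the header comment, here is a hypothetical miniature (simplified structs in place of the real macro-based visitor): Func runs only for tensors the config requests, so a skipped tensor's scale stays indeterminate and callers must gate on the same config flag instead of relying on the removed set_scale(1.0f) fallback.

#include <cstdio>

struct Tensor {
  float scale;  // written only if the visitor reaches this tensor
};

struct Layer {
  Tensor pre_attention_norm_scale;
  Tensor ffw_gating_biases;
};

struct NoBiasCfg {
  static constexpr bool kFFBiases = false;
};

// Miniature of the new contract: no else branch initializing unused
// tensors; they are simply never visited.
template <class TConfig, class Func>
void ForEachTensorSketch(Layer& layer, Func& func) {
  func("pre_att_ns", layer.pre_attention_norm_scale);  // unconditional
  if (TConfig::kFFBiases) {
    func("ffw_gat_b", layer.ffw_gating_biases);
  }
}

int main() {
  Layer layer;
  auto init = [](const char* name, Tensor& t) {
    t.scale = 1.0f;
    std::printf("initialized %s\n", name);
  };
  ForEachTensorSketch<NoBiasCfg>(layer, init);
  // layer.ffw_gating_biases.scale is still uninitialized here; per the new
  // header comment, code must check kFFBiases before using it.
  return 0;
}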