mirror of https://github.com/google/gemma.cpp.git
Fix msan uninitialized scale in optimize_test
PiperOrigin-RevId: 654817460
parent 74a6dc8f33
commit 33334ad454
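Summary of the change: the inference wrappers called data_scale1() on tensors that the model config never populates (post-norm scales, FF biases, attention output biases), reading an uninitialized scale that MemorySanitizer reports in optimize_test. The hunks below defer every data_scale1() call until after the config check, and drop the set_scale(1.0f) workaround from ForEachTensor. A minimal repro sketch of the failure mode, with hypothetical names modeled on the set_scale()/data_scale1() pattern visible in the diff:

// Hypothetical repro of the MSAN report; illustrative type, not the
// repo's CompressedArray. Build with: clang++ -fsanitize=memory repro.cc
#include <cassert>
#include <vector>

struct ScaledArray {
  std::vector<float> data_;
  float scale_;  // Only initialized when the tensor is actually loaded.
  void set_scale(float s) { scale_ = s; }
  const float* data_scale1() const {
    assert(scale_ == 1.0f);  // Reads scale_, initialized or not.
    return data_.data();
  }
};

int main() {
  ScaledArray unused_bias;          // Config does not request this tensor,
                                    // so nothing ever calls set_scale().
  (void)unused_bias.data_scale1();  // Uninitialized read -> MSAN error.
}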
gemma/gemma-inl.h

@@ -353,9 +353,11 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         activations.att_post2.Batch(interleaved_idx);
     // Head 0 (and potentially biases) -> layer_out.
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
-    MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-        layer_weights->attn_vec_einsum_w, 0, att_out,
-        layer_weights->attention_output_biases.data_scale1(),
+    constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
+    const float* bias =
+        kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr;
+    MatVecT<kAdd, kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, 0, att_out, bias,
         activations.even_odd.All(), layer_out, pool);
     // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
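The ternary only evaluates data_scale1() when kAdd is true, so configs without softmax-attention output biases never touch the uninitialized scale, and the kAdd=false instantiation never dereferences the nullptr. A self-contained sketch of that guarded-bias shape (illustrative MatVecAddMaybe, not the repo's MatVecT):

#include <cstddef>

// Illustrative stand-in for MatVecT's kAdd template flag: bias may be
// nullptr whenever kAdd is false, and is then never dereferenced.
template <bool kAdd>
void MatVecAddMaybe(const float* w, const float* x, const float* bias,
                    float* out, size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    float acc = kAdd ? bias[r] : 0.0f;
    for (size_t c = 0; c < cols; ++c) acc += w[r * cols + c] * x[c];
    out[r] = acc;
  }
}

At the call site, const float* bias = kAdd ? biases.data_scale1() : nullptr; keeps the read behind the flag.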
@@ -425,8 +427,14 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
-  const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
-  const auto bias2 = bias1 + kFFHiddenDim;
+  const float* bias1 = nullptr;
+  const float* bias2 = nullptr;
+  const float* output_bias = nullptr;
+  if constexpr (kAddBias) {
+    bias1 = layer_weights->ffw_gating_biases.data_scale1();
+    bias2 = bias1 + kFFHiddenDim;
+    output_bias = layer_weights->ffw_output_biases.data_scale1();
+  }
 
   // Will go through GELU.
   MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B1, scale,
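if constexpr is what makes this safe: when kAddBias is false, the branch body is discarded during instantiation, so the data_scale1() calls never execute and all three pointers stay nullptr, matching the kAddBias=false instantiation of MatMul_4x4_Batch_Add that ignores them. A reduced sketch of the idiom (Weights is an illustrative placeholder for the layer-weights type):

#include <cstddef>

// Resolve optional bias pointers behind a compile-time flag, as FFW now
// does. Illustrative; Weights is any type exposing ffw_gating_biases.
template <bool kAddBias, class Weights>
void ResolveFFWBiases(const Weights& w, size_t ff_hidden_dim,
                      const float** bias1, const float** bias2) {
  *bias1 = nullptr;
  *bias2 = nullptr;
  if constexpr (kAddBias) {
    // Only compiled and executed when the config declares FF biases,
    // so the tensors' scales are guaranteed to be initialized here.
    *bias1 = w.ffw_gating_biases.data_scale1();
    *bias2 = *bias1 + ff_hidden_dim;
  }
}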
@@ -442,7 +450,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
       num_interleaved, C1, layer_weights->linear_w.data(),
       layer_weights->linear_w.scale(), activations.ffw_out.All(),
-      layer_weights->ffw_output_biases.data_scale1(), pool);
+      output_bias, pool);
 }
 
 // TODO: pass Activations.x instead of Activations.
@@ -477,9 +485,10 @@ HWY_NOINLINE void ResidualConnection(
 }
 
 template <class TConfig, typename WeightT, typename InOutT>
-void PostNorm(size_t num_interleaved, const WeightT* weights, InOutT* inout) {
+void PostNorm(size_t num_interleaved, const WeightT& weights, InOutT* inout) {
   if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_interleaved, weights, inout, TConfig::kModelDim);
+    RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
+                          TConfig::kModelDim);
   }
 }
 
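The signature change matters because C++ evaluates arguments before the callee runs: the old call sites computed data_scale1() even when kPostNorm != Scale, so PostNorm's internal check came too late. Passing the tensor by reference defers the read until after the check. A reduced illustration (hypothetical Tensor, not the repo's type):

#include <cstdio>

// Hypothetical tensor with the same hazard: reading scale_ before it is
// set is what MSAN reports.
struct Tensor {
  float scale_;       // Possibly uninitialized.
  float value = 0.0f;
  float data_scale1() const { return value * scale_; }
};

void UseEager(bool enabled, float v) {         // Old shape: value parameter.
  if (enabled) std::printf("%f\n", v);
}

void UseLazy(bool enabled, const Tensor& t) {  // New shape: pass the tensor.
  if (enabled) std::printf("%f\n", t.data_scale1());
}

int main() {
  Tensor unused;                          // scale_ never initialized.
  UseEager(false, unused.data_scale1());  // Argument evaluated anyway: MSAN error.
  UseLazy(false, unused);                 // scale_ never read: clean.
}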
@@ -501,8 +510,7 @@ HWY_NOINLINE void TransformerLayer(
   Attention<TConfig>(type, pos, num_tokens, num_queries, layer_of_type,
                      activations, layer_weights, kv_caches, pool);
 
-  PostNorm<TConfig>(num_interleaved,
-                    layer_weights->post_attention_norm_scale.data_scale1(),
+  PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
                     activations.att_post2.All());
 
   ResidualConnection<TConfig>(num_interleaved, activations.att_post2.All(),
@@ -515,8 +523,7 @@ HWY_NOINLINE void TransformerLayer(
 
   FFW<TConfig>(activations, num_interleaved, layer_weights, pool);
 
-  PostNorm<TConfig>(num_interleaved,
-                    layer_weights->post_ffw_norm_scale.data_scale1(),
+  PostNorm<TConfig>(num_interleaved, layer_weights->post_ffw_norm_scale,
                     activations.ffw_out.All());
 
   ResidualConnection<TConfig>(num_interleaved, activations.ffw_out.All(),
gemma/weights.h

@@ -230,6 +230,11 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 //
 // This avoids repeating the list of tensors between loading and compressing,
 // while also avoiding dependency on raw_weights.h.
+//
+// This only calls Func for tensors that TConfig requests/specifies, which means
+// scale() is uninitialized for the other tensors, so their data_scale1() must
+// not be called. (In other words, if the config doesn't specify a tensor, it
+// shouldn't be used.)
 template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
 void ForEachTensor(RawWeightsPtr raw_weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
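With that contract documented, the set_scale(1.0f) else branches that ForEachTensor previously needed as a workaround (and the comment justifying them) become dead weight: the inference-side changes above guarantee data_scale1() is never called on unrequested tensors. The next hunk deletes them: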
@@ -269,33 +274,19 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     }
     GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
 
-    // For conditionally-included tensors, the else branch must ensure their
-    // scale is initialized, because wrapper functions call data_scale1 even if
-    // the tensor turns out to be unused. If unused, the arrays are zero-length
-    // and data() returns a non-null but unusable pointer.
     if (TConfig::kPostNorm == PostNormType::Scale) {
       GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
       GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
-    } else {
-      c_layer->post_attention_norm_scale.set_scale(1.0f);
-      c_layer->post_ffw_norm_scale.set_scale(1.0f);
     }
 
     if (TConfig::kFFBiases) {
       GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
       GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
-    } else {
-      c_layer->ffw_gating_biases.set_scale(1.0f);
-      c_layer->ffw_output_biases.set_scale(1.0f);
     }
 
-    if (type == LayerAttentionType::kGemma) {
-      if (TConfig::kSoftmaxAttnOutputBiases) {
-        GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
-      } else {
-        c_layer->attention_output_biases.set_scale(1.0f);
-      }
-    }
+    if (TConfig::kSoftmaxAttnOutputBiases &&
+        type == LayerAttentionType::kGemma) {
+      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
+    }
   }
 #undef GEMMA_CALL_FUNC
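For readers unfamiliar with the GEMMA_CALL_FUNC pattern: it invokes func on a name plus the matching raw and compressed tensors, so one ForEachTensor walk can drive both loading and compressing, as the header comment says. A simplified sketch of the idea (illustrative macro and types, not the repo's exact definitions):

#include <cstdio>

// Simplified visitor in the spirit of ForEachTensor / GEMMA_CALL_FUNC.
struct Tensor { const char* data; };
struct Layer { Tensor att_w, ffw_w; };

#define CALL_FUNC(name, member) func(name, raw.member, compressed.member)

template <class Func>
void ForEachLayerTensor(const Layer& raw, Layer& compressed, Func& func) {
  CALL_FUNC("att_w", att_w);
  CALL_FUNC("ffw_w", ffw_w);
}
#undef CALL_FUNC

int main() {
  Layer raw{{"raw_att"}, {"raw_ffw"}}, comp{{"c_att"}, {"c_ffw"}};
  auto print = [](const char* name, const Tensor& r, Tensor& c) {
    std::printf("%s: %s -> %s\n", name, r.data, c.data);
  };
  ForEachLayerTensor(raw, comp, print);
}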