diff --git a/gemma/weights.h b/gemma/weights.h
index 2346d67..d6ebd62 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -129,12 +129,21 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed

+// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
+// not yet support smaller compressed types, or require at least bf16. When
+// weights are f32, we also want such tensors to be f32.
+template <class TConfig>
+using WeightF32OrBF16T =
+    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
+            hwy::bfloat16_t>;
+
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.

   using TLayer = gcpp::LayerF<TConfig>;
   using WeightT = typename TConfig::WeightT;
+  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;

   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
@@ -180,11 +189,11 @@ struct CompressedLayer {
   ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
   ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
   // We don't yet have an RMSNorm that accepts all WeightT.
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
+  ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;

   ArrayT<float, kFFHiddenDim * 2> ffw_gating_biases;
   ArrayT<float, kModelDim> ffw_output_biases;
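
A minimal, self-contained sketch of how an alias like `WeightF32OrBF16T` resolves: f32 for f32-weight configs, bf16 otherwise. The two config structs below are hypothetical stand-ins, not the real configs from `gemma/configs.h`; only the Highway helpers (`hwy::If`, `hwy::IsSame`, `hwy::bfloat16_t`) are existing APIs.

```cpp
// Illustrative sketch only; ConfigF32Weights/ConfigBF16Weights are invented
// stand-ins for the real model configs.
#include "hwy/base.h"  // hwy::If, hwy::IsSame, hwy::bfloat16_t

// Same shape as the alias in the hunk above: pick float when the weight type
// is float, otherwise fall back to bf16.
template <class TConfig>
using WeightF32OrBF16T =
    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
            hwy::bfloat16_t>;

struct ConfigF32Weights { using WeightT = float; };             // uncompressed
struct ConfigBF16Weights { using WeightT = hwy::bfloat16_t; };  // any non-f32

// Norm scales stay f32 when the weights are f32, and are at least bf16
// otherwise.
static_assert(hwy::IsSame<WeightF32OrBF16T<ConfigF32Weights>, float>(), "");
static_assert(
    hwy::IsSame<WeightF32OrBF16T<ConfigBF16Weights>, hwy::bfloat16_t>(), "");
```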