From 12707ade80b2b3cecfd57ce9acce900005a48d5d Mon Sep 17 00:00:00 2001
From: Jan Wassenberg <janwas@google.com>
Date: Thu, 6 Jun 2024 10:59:46 -0700
Subject: [PATCH] Toward only using compressed weights: CompressedLayer should
 all be f32 when weights are f32.

PiperOrigin-RevId: 640954519
---
 gemma/weights.h | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/gemma/weights.h b/gemma/weights.h
index 2346d67..d6ebd62 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -129,12 +129,21 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed
 
+// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
+// not yet support smaller compressed types, or require at least bf16. When
+// weights are f32, we also want such tensors to be f32.
+template <class TConfig>
+using WeightF32OrBF16T =
+    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
+            hwy::bfloat16_t>;
+
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
 
   using TLayer = gcpp::LayerF<TConfig>;
   using WeightT = typename TConfig::WeightT;
+  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
 
   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
@@ -180,11 +189,11 @@ struct CompressedLayer {
   ArrayT<WeightT, kGatingEinsumWSize> gating_einsum_w;
   ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
   // We don't yet have an RMSNorm that accepts all WeightT.
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
+  ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
 
   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;