Toward only using compressed weights:

CompressedLayer tensors should all be f32 when the weights are f32.

PiperOrigin-RevId: 640954519
Authored by Jan Wassenberg on 2024-06-06 10:59:46 -07:00, committed by Copybara-Service
parent 6c0be20fa6
commit 12707ade80
1 changed file with 13 additions and 4 deletions


@@ -129,12 +129,21 @@ using WeightsF = Weights<float, TConfig>;
 
 // ----------------------------------------------------------------------------
 // Compressed
 
+// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
+// not yet support smaller compressed types, or require at least bf16. When
+// weights are f32, we also want such tensors to be f32.
+template <class TConfig>
+using WeightF32OrBF16T =
+    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
+            hwy::bfloat16_t>;
+
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
   using TLayer = gcpp::LayerF<TConfig>;
   using WeightT = typename TConfig::WeightT;
+  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
 
   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
@@ -180,11 +189,11 @@ struct CompressedLayer {
   ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
   ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
   // We don't yet have an RMSNorm that accepts all WeightT.
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
+  ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
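
For reference, a minimal sketch (not part of the commit) of how the new WeightF32OrBF16T alias resolves: hwy::IsSame compares the config's WeightT against float at compile time, and hwy::If then selects float or hwy::bfloat16_t. ConfigF32 and ConfigBF16 are hypothetical stand-ins for real model configs; only their WeightT member matters here.

// Sketch only: hypothetical configs exercising the alias added in this commit.
#include "hwy/base.h"  // hwy::If, hwy::IsSame, hwy::bfloat16_t

struct ConfigF32 { using WeightT = float; };             // uncompressed f32 weights
struct ConfigBF16 { using WeightT = hwy::bfloat16_t; };  // non-f32 weights (e.g. bf16)

template <class TConfig>
using WeightF32OrBF16T =
    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
            hwy::bfloat16_t>;

// f32 weights keep the norm scales in f32; any other weight type falls back to
// bf16, so ops that require at least bf16 still work.
static_assert(hwy::IsSame<WeightF32OrBF16T<ConfigF32>, float>());
static_assert(hwy::IsSame<WeightF32OrBF16T<ConfigBF16>, hwy::bfloat16_t>());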