mirror of https://github.com/google/gemma.cpp.git
Toward only using compressed weights:
Tensors in CompressedLayer should all be f32 when the weights are f32.

PiperOrigin-RevId: 640954519
commit 12707ade80
parent 6c0be20fa6
@@ -129,12 +129,21 @@ using WeightsF = Weights<float, TConfig>;
 // ----------------------------------------------------------------------------
 // Compressed
 
+// If weights are f32, also f32; otherwise at least bf16. Useful for ops that do
+// not yet support smaller compressed types, or require at least bf16. When
+// weights are f32, we also want such tensors to be f32.
+template <class TConfig>
+using WeightF32OrBF16T =
+    hwy::If<hwy::IsSame<typename TConfig::WeightT, float>(), float,
+            hwy::bfloat16_t>;
+
 template <class TConfig>
 struct CompressedLayer {
   // No ctor/dtor, allocated via AllocateAligned.
 
   using TLayer = gcpp::LayerF<TConfig>;
   using WeightT = typename TConfig::WeightT;
+  using WeightF32OrBF16 = WeightF32OrBF16T<TConfig>;
 
   static constexpr size_t kHeads = TLayer::kHeads;
   static constexpr size_t kKVHeads = TLayer::kKVHeads;
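To make the compile-time selection concrete, here is a minimal standalone sketch of the same dispatch. It substitutes std::conditional_t and std::is_same_v for hwy::If and hwy::IsSame (which behave the same way), and the BF16/SFP stand-in types and ConfigF32/ConfigSFP configs are hypothetical placeholders, not the real gemma.cpp configs:

    // Standalone sketch (C++17) of WeightF32OrBF16T's type selection.
    #include <type_traits>

    struct BF16 {};  // stand-in for hwy::bfloat16_t (hypothetical)
    struct SFP {};   // stand-in for a smaller compressed weight type (hypothetical)

    struct ConfigF32 { using WeightT = float; };  // hypothetical f32 config
    struct ConfigSFP { using WeightT = SFP; };    // hypothetical compressed config

    // Same logic as the hwy::If expression above: f32 stays f32,
    // everything else is widened to at least bf16.
    template <class TConfig>
    using WeightF32OrBF16T =
        std::conditional_t<std::is_same_v<typename TConfig::WeightT, float>,
                           float, BF16>;

    static_assert(std::is_same_v<WeightF32OrBF16T<ConfigF32>, float>);
    static_assert(std::is_same_v<WeightF32OrBF16T<ConfigSFP>, BF16>);

    int main() { return 0; }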
@@ -180,11 +189,11 @@ struct CompressedLayer {
   ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
   ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
   // We don't yet have an RMSNorm that accepts all WeightT.
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
+  ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
 
   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
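Before this change the norm-scale tensors were pinned to hwy::bfloat16_t regardless of WeightT; after it, their element type follows the config. The sketch below checks exactly that, reusing the hypothetical stand-ins from the previous sketch (repeated so it is self-contained) plus a simplified ArrayT that is not the real gemma.cpp ArrayT:

    // Sketch (C++17): the norm-scale member's element type now tracks the config.
    #include <cstddef>
    #include <type_traits>

    struct BF16 {};  // stand-in for hwy::bfloat16_t (hypothetical)
    struct SFP {};   // stand-in for a compressed weight type (hypothetical)
    struct ConfigF32 { using WeightT = float; };  // hypothetical config
    struct ConfigSFP { using WeightT = SFP; };    // hypothetical config

    template <class TConfig>
    using WeightF32OrBF16T =
        std::conditional_t<std::is_same_v<typename TConfig::WeightT, float>,
                           float, BF16>;

    // Simplified stand-in for the real ArrayT.
    template <class T, std::size_t N>
    struct ArrayT { T data[N]; };

    template <class TConfig>
    struct LayerSketch {
      static constexpr std::size_t kModelDim = 16;  // hypothetical dimension
      ArrayT<WeightF32OrBF16T<TConfig>, kModelDim> pre_ffw_norm_scale;
    };

    // f32 weights now get f32 norm scales; compressed weights still get bf16,
    // matching the intent of the commit message.
    static_assert(std::is_same_v<
        decltype(LayerSketch<ConfigF32>::pre_ffw_norm_scale), ArrayT<float, 16>>);
    static_assert(std::is_same_v<
        decltype(LayerSketch<ConfigSFP>::pre_ffw_norm_scale), ArrayT<BF16, 16>>);

    int main() { return 0; }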