diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 00e98ed..5bb11d9 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -84,32 +84,45 @@ struct Layer {
       (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
   // 2x for (gelu gating vector, gated vector)
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
+  static constexpr bool kFFBiases = TConfig::kFFBiases;
+  static constexpr size_t kAOBiaseDim =
+      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
 
-  std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
-  std::array<float, kQKVEinsumWSize> qkv_einsum_w;
-  std::array<float, kGatingEinsumWSize> gating_einsum_w;
-  std::array<float, kModelDim * kFFHiddenDim> linear_w;
-  std::array<float, kModelDim> pre_attention_norm_scale;
-  std::array<float, kModelDim> pre_ffw_norm_scale;
-  // These fields are only used by Griffin, and do not affect loading of the
-  // model as it is done per-member.
-  // TODO(veluca): pull weights that are never used at the same time into a
-  // union or otherwise reduce the memory usage.
-  std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
-  std::array<float, kModelDim> ffw_output_biases;
-  std::array<float, kModelDim> attention_output_biases;
+  template <class T, size_t N>
+  using ArrayT = std::array<T, N>;
 
-  std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
-  std::array<float, kModelDim> griffin_linear_y_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
-  std::array<float, kModelDim> griffin_linear_x_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
-  std::array<float, kModelDim> griffin_linear_out_biases;
-  std::array<float, kModelDim> griffin_conv_biases;
-  std::array<float, kModelDim * kModelDim / kHeads * 2> griffin_gate_w;
-  std::array<float, kModelDim * 2> griffin_gate_biases;
-  std::array<float, kModelDim> griffin_a;
-  std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  union {
+    struct {
+      ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiaseDim> attention_output_biases;
+    };
+
+    struct {
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_x_w;
+      ArrayT<float, kGriffinDim> griffin_linear_x_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_y_w;
+      ArrayT<float, kGriffinDim> griffin_linear_y_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_out_w;
+      ArrayT<float, kGriffinDim> griffin_linear_out_biases;
+      ArrayT<float, kConv1dWidth * kGriffinDim> griffin_conv_w;
+      ArrayT<float, kGriffinDim> griffin_conv_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
+      ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
+      ArrayT<float, kGriffinDim> griffin_a;
+    };
+  };
+
+  ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
+  ArrayT<float, kModelDim> pre_attention_norm_scale;
+  ArrayT<float, kModelDim> pre_ffw_norm_scale;
+
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 float ScaleWeights(float* data, size_t len) {
@@ -287,33 +300,54 @@ struct CompressedLayer {
 
   using TLayer = gcpp::Layer<TConfig>;
   using WeightT = typename TConfig::WeightT;
 
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kHeads = TLayer::kHeads;
+  static constexpr size_t kKVHeads = TLayer::kKVHeads;
+  static constexpr size_t kModelDim = TLayer::kModelDim;
+  static constexpr size_t kQKVDim = TLayer::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
+  static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
+  static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
+  static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
+  static constexpr bool kFFBiases = TLayer::kFFBiases;
+  static constexpr size_t kAOBiaseDim = TLayer::kAOBiaseDim;
+  static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
 
   // Compressed Parameters
-  // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<float, kModelDim> pre_attention_norm_scale;
-  CompressedArray<float, kModelDim> pre_ffw_norm_scale;
-  CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
-  CompressedArray<float, kModelDim> ffw_output_biases;
-  CompressedArray<float, kModelDim> attention_output_biases;
-  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
-  CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
-  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
-  CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
-      griffin_gate_w;
-  CompressedArray<float, kModelDim> griffin_a;
-  CompressedArray<float, kModelDim> griffin_linear_y_biases;
-  CompressedArray<float, kModelDim> griffin_linear_x_biases;
-  CompressedArray<float, kModelDim> griffin_linear_out_biases;
-  CompressedArray<float, kModelDim> griffin_conv_biases;
-  CompressedArray<float, kModelDim * 2> griffin_gate_biases;
-  CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  template <class T, size_t N>
+  using ArrayT = CompressedArray<T, N>;
+
+  union {
+    struct {
+      ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiaseDim> attention_output_biases;
+    };
+
+    struct {
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_x_w;
+      ArrayT<float, kGriffinDim> griffin_linear_x_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_y_w;
+      ArrayT<float, kGriffinDim> griffin_linear_y_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_out_w;
+      ArrayT<float, kGriffinDim> griffin_linear_out_biases;
+      ArrayT<float, kConv1dWidth * kGriffinDim> griffin_conv_w;
+      ArrayT<float, kGriffinDim> griffin_conv_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
+      ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
+      ArrayT<float, kGriffinDim> griffin_a;
+    };
+  };
+
+  ArrayT<WeightT, kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
+  // We don't yet have an RMSNorm that accepts all WeightT.
+  ArrayT<float, kModelDim> pre_attention_norm_scale;
+  ArrayT<float, kModelDim> pre_ffw_norm_scale;
+
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 // Array instead of single large allocation for parallel mem init. Split out
@@ -387,10 +421,12 @@ struct Activations {
   std::array<float, kBatchSize * kVocabSize> logits;
 
   // Griffin layer internal activations
-  std::array<float, kBatchSize * kModelDim> griffin_x;
-  std::array<float, kBatchSize * kModelDim> griffin_y;
-  std::array<float, kBatchSize * kModelDim> griffin_gate_x;
-  std::array<float, kBatchSize * kModelDim> griffin_multiplier;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
+  std::array<float, kBatchSize * kGriffinDim> griffin_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_y;
+  std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
 };
 
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
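Note on the union layout: the effect is the one asked for by the removed TODO(veluca) comment. A given layer is either an attention layer or a Griffin recurrent layer, so the two weight groups are never used at the same time; overlaying them in a union makes each layer's storage the size of the larger group instead of the sum of both. A minimal standalone sketch of that effect follows; the dimensions and struct names in it are made up for illustration and are not taken from any real Gemma or Griffin config.

// Toy comparison of separate members vs. the union overlay used in the patch.
// All dimensions below are placeholders, not real model configs.
#include <array>
#include <cstddef>
#include <cstdio>

constexpr std::size_t kModelDim = 64;
constexpr std::size_t kHeads = 4;
constexpr std::size_t kQKVDim = 16;
constexpr std::size_t kConv1dWidth = 4;

// Both weight groups as plain members: the layer pays for both sets even
// though any one layer only ever uses one of them.
struct LayerSeparate {
  std::array<float, kHeads * kQKVDim * kModelDim> attn_vec_einsum_w;
  std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
  std::array<float, kConv1dWidth * kModelDim> griffin_conv_w;
};

// Same members overlaid in an anonymous union (anonymous structs in a union
// are a GCC/Clang/MSVC extension, used the same way in gemma.cc): the struct
// is as large as the bigger alternative, not the sum.
struct LayerUnion {
  union {
    struct {
      std::array<float, kHeads * kQKVDim * kModelDim> attn_vec_einsum_w;
    };
    struct {
      std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
      std::array<float, kConv1dWidth * kModelDim> griffin_conv_w;
    };
  };
};

int main() {
  std::printf("separate: %zu bytes\n", sizeof(LayerSeparate));
  std::printf("union:    %zu bytes\n", sizeof(LayerUnion));  // max, not sum
}

As in the patch, which member set is valid for a given layer is decided outside the struct (by the layer type implied by TConfig::kGriffinLayers); only that set may be read or written, and the sketch above only compares sizes.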
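The Activations hunk uses the same zero-extent trick as kAOBiaseDim and kGriffinDim in the weight structs: when a config has no Griffin layers, the Griffin scratch arrays become std::array<float, 0>, which typically occupies a single byte instead of kBatchSize * kModelDim floats. A rough sketch with arbitrary placeholder dimensions (ToyActivations and its parameters are invented for illustration):

// Toy illustration of sizing activations with a conditional zero extent.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t kGriffinLayers, std::size_t kModelDim,
          std::size_t kBatchSize>
struct ToyActivations {
  static constexpr std::size_t kGriffinDim =
      kGriffinLayers > 0 ? kModelDim : 0;
  std::array<float, kBatchSize * kModelDim> x;            // always needed
  std::array<float, kBatchSize * kGriffinDim> griffin_x;  // Griffin only
  std::array<float, kBatchSize * kGriffinDim> griffin_y;  // Griffin only
};

int main() {
  // A Griffin-style config keeps full-size scratch buffers; a pure-attention
  // config pays only a few bytes for the empty arrays.
  std::printf("with griffin:    %zu bytes\n",
              sizeof(ToyActivations</*kGriffinLayers=*/1, 2048, 4>));
  std::printf("without griffin: %zu bytes\n",
              sizeof(ToyActivations</*kGriffinLayers=*/0, 2048, 4>));
}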