diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 00e98ed..b129443 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -84,32 +84,45 @@ struct Layer {
       (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
   // 2x for (gelu gating vector, gated vector)
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
+  static constexpr bool kFFBiases = TConfig::kFFBiases;
+  static constexpr size_t kAOBiasDim =
+      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
 
-  std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
-  std::array<float, kQKVEinsumWSize> qkv_einsum_w;
-  std::array<float, kGatingEinsumWSize> gating_einsum_w;
-  std::array<float, kModelDim * kFFHiddenDim> linear_w;
-  std::array<float, kModelDim> pre_attention_norm_scale;
-  std::array<float, kModelDim> pre_ffw_norm_scale;
-  // These fields are only used by Griffin, and do not affect loading of the
-  // model as it is done per-member.
-  // TODO(veluca): pull weights that are never used at the same time into a
-  // union or otherwise reduce the memory usage.
-  std::array<float, kFFHiddenDim * 2> ffw_gating_biases;
-  std::array<float, kModelDim> ffw_output_biases;
-  std::array<float, kModelDim> attention_output_biases;
+  template <typename T, size_t N>
+  using ArrayT = std::array<T, N>;
 
-  std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
-  std::array<float, kModelDim> griffin_linear_y_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
-  std::array<float, kModelDim> griffin_linear_x_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
-  std::array<float, kModelDim> griffin_linear_out_biases;
-  std::array<float, kModelDim> griffin_conv_biases;
-  std::array<float, kModelDim * kModelDim / kHeads * 2> griffin_gate_w;
-  std::array<float, kModelDim * 2> griffin_gate_biases;
-  std::array<float, kModelDim> griffin_a;
-  std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  union {
+    struct {
+      ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiasDim> attention_output_biases;
+    };
+
+    struct {
+      ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
+      ArrayT<float, kGriffinDim> linear_x_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
+      ArrayT<float, kGriffinDim> linear_y_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
+      ArrayT<float, kGriffinDim> linear_out_biases;
+      ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
+      ArrayT<float, kGriffinDim> conv_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
+      ArrayT<float, kGriffinDim * 2> gate_biases;
+      ArrayT<float, kGriffinDim> a;
+    } griffin;
+  };
+
+  ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
+  ArrayT<float, kModelDim> pre_attention_norm_scale;
+  ArrayT<float, kModelDim> pre_ffw_norm_scale;
+
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 float ScaleWeights(float* data, size_t len) {
@@ -233,21 +246,21 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
     SCALE_WEIGHTS(attn_vec_einsum_w);
     SCALE_WEIGHTS(qkv_einsum_w);
   } else {
-    READ_WEIGHTS(griffin_linear_x_w);
-    READ_WEIGHTS(griffin_linear_x_biases);
-    READ_WEIGHTS(griffin_linear_y_w);
-    READ_WEIGHTS(griffin_linear_y_biases);
-    READ_WEIGHTS(griffin_linear_out_w);
-    READ_WEIGHTS(griffin_linear_out_biases);
-    READ_WEIGHTS(griffin_conv_w);
-    READ_WEIGHTS(griffin_conv_biases);
-    READ_WEIGHTS(griffin_gate_w);
-    READ_WEIGHTS(griffin_gate_biases);
-    READ_WEIGHTS(griffin_a);
-    SCALE_WEIGHTS(griffin_linear_x_w);
-    SCALE_WEIGHTS(griffin_linear_y_w);
-    SCALE_WEIGHTS(griffin_linear_out_w);
-    SCALE_WEIGHTS(griffin_gate_w);
+    READ_WEIGHTS(griffin.linear_x_w);
+    READ_WEIGHTS(griffin.linear_x_biases);
+    READ_WEIGHTS(griffin.linear_y_w);
+    READ_WEIGHTS(griffin.linear_y_biases);
+    READ_WEIGHTS(griffin.linear_out_w);
+    READ_WEIGHTS(griffin.linear_out_biases);
+    READ_WEIGHTS(griffin.conv_w);
+    READ_WEIGHTS(griffin.conv_biases);
+    READ_WEIGHTS(griffin.gate_w);
+    READ_WEIGHTS(griffin.gate_biases);
+    READ_WEIGHTS(griffin.a);
+    SCALE_WEIGHTS(griffin.linear_x_w);
+    SCALE_WEIGHTS(griffin.linear_y_w);
+    SCALE_WEIGHTS(griffin.linear_out_w);
+    SCALE_WEIGHTS(griffin.gate_w);
   }
   READ_WEIGHTS(gating_einsum_w);
   READ_WEIGHTS(linear_w);
@@ -287,33 +300,54 @@ struct CompressedLayer {
   using TLayer = gcpp::Layer<TConfig>;
   using WeightT = typename TConfig::WeightT;
 
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kHeads = TLayer::kHeads;
+  static constexpr size_t kKVHeads = TLayer::kKVHeads;
+  static constexpr size_t kModelDim = TLayer::kModelDim;
+  static constexpr size_t kQKVDim = TLayer::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
+  static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
+  static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
+  static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
+  static constexpr bool kFFBiases = TLayer::kFFBiases;
+  static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
+  static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
 
   // Compressed Parameters
-  // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  CompressedArray<float, kFFHiddenDim * 2> ffw_gating_biases;
-  CompressedArray<float, kModelDim> ffw_output_biases;
-  CompressedArray<float, kModelDim> attention_output_biases;
-  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
-  CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
-  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
-  CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
-      griffin_gate_w;
-  CompressedArray<float, kModelDim> griffin_a;
-  CompressedArray<float, kModelDim> griffin_linear_y_biases;
-  CompressedArray<float, kModelDim> griffin_linear_x_biases;
-  CompressedArray<float, kModelDim> griffin_linear_out_biases;
-  CompressedArray<float, kModelDim> griffin_conv_biases;
-  CompressedArray<float, kModelDim * 2> griffin_gate_biases;
-  CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  template <typename T, size_t N>
+  using ArrayT = CompressedArray<T, N>;
+
+  union {
+    struct {
+      ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiasDim> attention_output_biases;
+    };
+
+    struct {
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
+      ArrayT<float, kGriffinDim> linear_x_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
+      ArrayT<float, kGriffinDim> linear_y_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
+      ArrayT<float, kGriffinDim> linear_out_biases;
+      ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
+      ArrayT<float, kGriffinDim> conv_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
+      ArrayT<float, kGriffinDim * 2> gate_biases;
+      ArrayT<float, kGriffinDim> a;
+    } griffin;
+  };
+
+  ArrayT<WeightT, kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
+  // We don't yet have an RMSNorm that accepts all WeightT.
+  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
+  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 // Array instead of single large allocation for parallel mem init. Split out
@@ -387,10 +421,12 @@ struct Activations {
   std::array<float, kBatchSize * kVocabSize> logits;
 
   // Griffin layer internal activations
-  std::array<float, kBatchSize * kModelDim> griffin_x;
-  std::array<float, kBatchSize * kModelDim> griffin_y;
-  std::array<float, kBatchSize * kModelDim> griffin_gate_x;
-  std::array<float, kBatchSize * kModelDim> griffin_multiplier;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
+  std::array<float, kBatchSize * kGriffinDim> griffin_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_y;
+  std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
 };
 
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
@@ -541,10 +577,10 @@ HWY_NOINLINE void GriffinRecurrent(
   float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
   float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
   TwoMatVecAdd<kModelDim, kModelDim>(
-      layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
+      layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
       activations.pre_att_rms_out.data() + batch_offset,
-      /*add0=*/layer_weights->griffin_linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
+      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
+      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
       /*out1=*/y, pool);
 
   Gelu(y, kModelDim);
@@ -564,13 +600,13 @@
   }
   for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
     auto xv = hn::Load(df, x + i);
-    auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
+    auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
     auto accum1 = hn::Zero(df);
     static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
     for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
-      auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
+      auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                               (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
-      auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
+      auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                               (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
       accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
       accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@@ -591,10 +627,10 @@
     constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
     size_t head_offset = head * kHeadDim;
     TwoOfsMatVecAddLoop<kHeadDim, kHeadDim>(
-        layer_weights->griffin_gate_w, kMatrixSize * head,
+        layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
+        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
+        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
             head_offset,
         /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
     Sigmoid(gate_x + head_offset, kHeadDim);
@@ -602,7 +638,7 @@
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
                             HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin_a.data() + head_offset, fn_mul);
+                   layer_weights->griffin.a.data() + head_offset, fn_mul);
     hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                    fn_mul);
     // RNN scan
@@ -630,8 +666,8 @@
   // Final linear layer.
   float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
   MatVecAdd<kModelDim, kModelDim>(
-      layer_weights->griffin_linear_out_w, 0, x,
-      layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
+      layer_weights->griffin.linear_out_w, 0, x,
+      layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
 }
 
 template <class TConfig>
@@ -1238,17 +1274,17 @@ void ForEachTensor(const Weights<TConfig>* weights,
     CALL_FUNC("qkv_ein", qkv_einsum_w);
     CALL_FUNC("att_ein", attn_vec_einsum_w);
   } else {
-    CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
-    CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
-    CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
-    CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
-    CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
-    CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
-    CALL_FUNC("gr_conv_w", griffin_conv_w);
-    CALL_FUNC("gr_conv_b", griffin_conv_biases);
-    CALL_FUNC("gr_gate_w", griffin_gate_w);
-    CALL_FUNC("gr_gate_b", griffin_gate_biases);
-    CALL_FUNC("gr_a", griffin_a);
+    CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
+    CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
+    CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
+    CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
+    CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
+    CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
+    CALL_FUNC("gr_conv_w", griffin.conv_w);
+    CALL_FUNC("gr_conv_b", griffin.conv_biases);
+    CALL_FUNC("gr_gate_w", griffin.gate_w);
+    CALL_FUNC("gr_gate_b", griffin.gate_biases);
+    CALL_FUNC("gr_a", griffin.a);
   }
 
   CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
@@ -1298,10 +1334,10 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
     layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
     layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
   } else {
-    layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
-    layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
-    layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
-    layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
+    layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
+    layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
+    layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
+    layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
   }
   layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
   layer_weights->linear_w.set_scale(scales[scale_pos++]);
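
Aside from the mechanical renames (griffin_linear_x_w becomes griffin.linear_x_w, and so on), the substantive change is the anonymous union: a layer is either a Gemma attention layer or a Griffin recurrent layer, never both, so the two weight sets can share storage, and the unused branch additionally collapses to zero-sized arrays via kAOBiasDim / kGriffinDim. The standalone sketch below illustrates the pattern; it is not part of the patch, and DemoLayer and its members are hypothetical names. Like the patch itself, it relies on anonymous structs inside a union, a widely supported compiler extension rather than strict standard C++.

// Hypothetical demo of the union-of-structs layout used in the patch.
#include <array>
#include <cstddef>
#include <cstdio>

template <bool kIsGriffin>
struct DemoLayer {
  // Zero-sized extents play the role of kAOBiasDim / kGriffinDim: the
  // branch a given config does not use occupies (almost) no space.
  static constexpr std::size_t kAttnDim = kIsGriffin ? 0 : 256;
  static constexpr std::size_t kGriffinDim = kIsGriffin ? 256 : 0;

  union {
    // Anonymous struct: attention members keep their unqualified names,
    // so existing call sites (layer.attn_w) need not change.
    struct {
      std::array<float, kAttnDim * kAttnDim> attn_w;
    };
    // Named member: recurrent weights are reached as layer.griffin.*.
    struct {
      std::array<float, kGriffinDim * kGriffinDim> linear_w;
      std::array<float, kGriffinDim> conv_biases;
    } griffin;
  };
};

int main() {
  // The union overlays the two structs, so each instantiation is roughly
  // the size of its live branch alone, not the sum of both branches.
  std::printf("attention layer: %zu bytes\n", sizeof(DemoLayer<false>));
  std::printf("griffin layer:   %zu bytes\n", sizeof(DemoLayer<true>));
  return 0;
}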