From e541707caa0b8d7c230b00e5768cd2d470a86863 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Wed, 10 Apr 2024 21:04:31 +0800
Subject: [PATCH] Rename the fields of Griffin weights

---
 gemma/gemma.cc | 132 ++++++++++++++++++++++++++++---------------------------
 1 file changed, 66 insertions(+), 66 deletions(-)

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index f2c6108..b129443 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -102,18 +102,18 @@ struct Layer {
     };
     struct {
-      ArrayT griffin_linear_x_w;
-      ArrayT griffin_linear_x_biases;
-      ArrayT griffin_linear_y_w;
-      ArrayT griffin_linear_y_biases;
-      ArrayT griffin_linear_out_w;
-      ArrayT griffin_linear_out_biases;
-      ArrayT griffin_conv_w;
-      ArrayT griffin_conv_biases;
-      ArrayT griffin_gate_w;
-      ArrayT griffin_gate_biases;
-      ArrayT griffin_a;
-    };
+      ArrayT linear_x_w;
+      ArrayT linear_x_biases;
+      ArrayT linear_y_w;
+      ArrayT linear_y_biases;
+      ArrayT linear_out_w;
+      ArrayT linear_out_biases;
+      ArrayT conv_w;
+      ArrayT conv_biases;
+      ArrayT gate_w;
+      ArrayT gate_biases;
+      ArrayT a;
+    } griffin;
   };
 
   ArrayT gating_einsum_w;
@@ -246,21 +246,21 @@ hwy::AlignedFreeUniquePtr LoadWeights(
       SCALE_WEIGHTS(attn_vec_einsum_w);
       SCALE_WEIGHTS(qkv_einsum_w);
     } else {
-      READ_WEIGHTS(griffin_linear_x_w);
-      READ_WEIGHTS(griffin_linear_x_biases);
-      READ_WEIGHTS(griffin_linear_y_w);
-      READ_WEIGHTS(griffin_linear_y_biases);
-      READ_WEIGHTS(griffin_linear_out_w);
-      READ_WEIGHTS(griffin_linear_out_biases);
-      READ_WEIGHTS(griffin_conv_w);
-      READ_WEIGHTS(griffin_conv_biases);
-      READ_WEIGHTS(griffin_gate_w);
-      READ_WEIGHTS(griffin_gate_biases);
-      READ_WEIGHTS(griffin_a);
-      SCALE_WEIGHTS(griffin_linear_x_w);
-      SCALE_WEIGHTS(griffin_linear_y_w);
-      SCALE_WEIGHTS(griffin_linear_out_w);
-      SCALE_WEIGHTS(griffin_gate_w);
+      READ_WEIGHTS(griffin.linear_x_w);
+      READ_WEIGHTS(griffin.linear_x_biases);
+      READ_WEIGHTS(griffin.linear_y_w);
+      READ_WEIGHTS(griffin.linear_y_biases);
+      READ_WEIGHTS(griffin.linear_out_w);
+      READ_WEIGHTS(griffin.linear_out_biases);
+      READ_WEIGHTS(griffin.conv_w);
+      READ_WEIGHTS(griffin.conv_biases);
+      READ_WEIGHTS(griffin.gate_w);
+      READ_WEIGHTS(griffin.gate_biases);
+      READ_WEIGHTS(griffin.a);
+      SCALE_WEIGHTS(griffin.linear_x_w);
+      SCALE_WEIGHTS(griffin.linear_y_w);
+      SCALE_WEIGHTS(griffin.linear_out_w);
+      SCALE_WEIGHTS(griffin.gate_w);
     }
     READ_WEIGHTS(gating_einsum_w);
     READ_WEIGHTS(linear_w);
@@ -326,18 +326,18 @@ struct CompressedLayer {
     };
     struct {
-      ArrayT griffin_linear_x_w;
-      ArrayT griffin_linear_x_biases;
-      ArrayT griffin_linear_y_w;
-      ArrayT griffin_linear_y_biases;
-      ArrayT griffin_linear_out_w;
-      ArrayT griffin_linear_out_biases;
-      ArrayT griffin_conv_w;
-      ArrayT griffin_conv_biases;
-      ArrayT griffin_gate_w;
-      ArrayT griffin_gate_biases;
-      ArrayT griffin_a;
-    };
+      ArrayT linear_x_w;
+      ArrayT linear_x_biases;
+      ArrayT linear_y_w;
+      ArrayT linear_y_biases;
+      ArrayT linear_out_w;
+      ArrayT linear_out_biases;
+      ArrayT conv_w;
+      ArrayT conv_biases;
+      ArrayT gate_w;
+      ArrayT gate_biases;
+      ArrayT a;
+    } griffin;
   };
 
   ArrayT gating_einsum_w;
@@ -577,10 +577,10 @@ HWY_NOINLINE void GriffinRecurrent(
   float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
   float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
   TwoMatVecAdd(
-      layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
+      layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
       activations.pre_att_rms_out.data() + batch_offset,
-      /*add0=*/layer_weights->griffin_linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
+      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
+      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
       /*out1=*/y, pool);
   Gelu(y, kModelDim);
@@ -600,13 +600,13 @@
   }
   for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
     auto xv = hn::Load(df, x + i);
-    auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
+    auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
     auto accum1 = hn::Zero(df);
     static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
     for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
-      auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
+      auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                               (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
-      auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
+      auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
                               (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
       accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
       accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@@ -627,10 +627,10 @@
     constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
     size_t head_offset = head * kHeadDim;
     TwoOfsMatVecAddLoop(
-        layer_weights->griffin_gate_w, kMatrixSize * head,
+        layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
+        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
+        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
             head_offset,
         /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
     Sigmoid(gate_x + head_offset, kHeadDim);
@@ -638,7 +638,7 @@
     const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x)
                             HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin_a.data() + head_offset, fn_mul);
+                   layer_weights->griffin.a.data() + head_offset, fn_mul);
     hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                    fn_mul);
     // RNN scan
@@ -666,8 +666,8 @@
   // Final linear layer.
   float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
   MatVecAdd(
-      layer_weights->griffin_linear_out_w, 0, x,
-      layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
+      layer_weights->griffin.linear_out_w, 0, x,
+      layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
 }
 
 template
@@ -1274,17 +1274,17 @@ void ForEachTensor(const Weights* weights,
     CALL_FUNC("qkv_ein", qkv_einsum_w);
     CALL_FUNC("att_ein", attn_vec_einsum_w);
   } else {
-    CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
-    CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
-    CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
-    CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
-    CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
-    CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
-    CALL_FUNC("gr_conv_w", griffin_conv_w);
-    CALL_FUNC("gr_conv_b", griffin_conv_biases);
-    CALL_FUNC("gr_gate_w", griffin_gate_w);
-    CALL_FUNC("gr_gate_b", griffin_gate_biases);
-    CALL_FUNC("gr_a", griffin_a);
+    CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
+    CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
+    CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
+    CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
+    CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
+    CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
+    CALL_FUNC("gr_conv_w", griffin.conv_w);
+    CALL_FUNC("gr_conv_b", griffin.conv_biases);
+    CALL_FUNC("gr_gate_w", griffin.gate_w);
+    CALL_FUNC("gr_gate_b", griffin.gate_biases);
+    CALL_FUNC("gr_a", griffin.a);
   }
 
   CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
@@ -1334,10 +1334,10 @@ hwy::AlignedFreeUniquePtr LoadCompressedWeights(
       layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
      layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
     } else {
-      layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
-      layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
-      layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
-      layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
+      layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
+      layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
+      layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
+      layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
     }
     layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
     layer_weights->linear_w.set_scale(scales[scale_pos++]);
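Note on the pattern: the patch groups the Griffin weights under a named nested struct member `griffin` and drops the `griffin_` prefix from each field, so every access site changes from `layer_weights->griffin_linear_x_w` to `layer_weights->griffin.linear_x_w`. Below is a minimal, self-contained sketch of that layout change; the simplified `Layer` stand-in and its `std::array` members are hypothetical and stand in for the real `ArrayT` declarations in gemma.cc.

```cpp
#include <array>
#include <cstdio>

// Hypothetical, stripped-down stand-in for the Layer struct in gemma.cc.
// Only the naming pattern is shown; the real fields use ArrayT types with
// model-dependent sizes and there are many more members.
struct Layer {
  // After the patch: the Griffin weights live in a named nested struct,
  // so each field no longer needs a "griffin_" prefix.
  struct {
    std::array<float, 8> linear_x_w;       // was griffin_linear_x_w
    std::array<float, 8> linear_x_biases;  // was griffin_linear_x_biases
    std::array<float, 8> a;                // was griffin_a
  } griffin;
};

int main() {
  Layer layer{};
  // Call sites switch from layer.griffin_linear_x_w to layer.griffin.linear_x_w.
  layer.griffin.linear_x_w[0] = 1.0f;
  layer.griffin.linear_x_biases[0] = 0.5f;
  std::printf("%g\n",
              layer.griffin.linear_x_w[0] + layer.griffin.linear_x_biases[0]);
  return 0;
}
```

The named member keeps the field list short and makes each access self-documenting about which weight group it touches, without changing the stored data or its layout on disk.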