Rename the fields of Griffin weights

This commit is contained in:
RangerUFO 2024-04-10 21:04:31 +08:00
parent 4e960d67f6
commit e541707caa
1 changed files with 66 additions and 66 deletions

View File

@ -102,18 +102,18 @@ struct Layer {
}; };
struct { struct {
ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_x_w; ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> griffin_linear_x_biases; ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_y_w; ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> griffin_linear_y_biases; ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_out_w; ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> griffin_linear_out_biases; ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, kConv1dWidth * kGriffinDim> griffin_conv_w; ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> griffin_conv_biases; ArrayT<float, kGriffinDim> conv_biases;
ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w; ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> griffin_gate_biases; ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> griffin_a; ArrayT<float, kGriffinDim> a;
}; } griffin;
}; };
ArrayT<float, kGatingEinsumWSize> gating_einsum_w; ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
@ -246,21 +246,21 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
SCALE_WEIGHTS(attn_vec_einsum_w); SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w); SCALE_WEIGHTS(qkv_einsum_w);
} else { } else {
READ_WEIGHTS(griffin_linear_x_w); READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin_linear_x_biases); READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin_linear_y_w); READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin_linear_y_biases); READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin_linear_out_w); READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin_linear_out_biases); READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin_conv_w); READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin_conv_biases); READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin_gate_w); READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin_gate_biases); READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin_a); READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin_linear_x_w); SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin_linear_y_w); SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin_linear_out_w); SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin_gate_w); SCALE_WEIGHTS(griffin.gate_w);
} }
READ_WEIGHTS(gating_einsum_w); READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w); READ_WEIGHTS(linear_w);
@ -326,18 +326,18 @@ struct CompressedLayer {
}; };
struct { struct {
ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_x_w; ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> griffin_linear_x_biases; ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_y_w; ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> griffin_linear_y_biases; ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_out_w; ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> griffin_linear_out_biases; ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> griffin_conv_w; ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> griffin_conv_biases; ArrayT<float, kGriffinDim> conv_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w; ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> griffin_gate_biases; ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> griffin_a; ArrayT<float, kGriffinDim> a;
}; } griffin;
}; };
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w; ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
@ -577,10 +577,10 @@ HWY_NOINLINE void GriffinRecurrent(
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
TwoMatVecAdd<true, kModelDim, kModelDim>( TwoMatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0, layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
activations.pre_att_rms_out.data() + batch_offset, activations.pre_att_rms_out.data() + batch_offset,
/*add0=*/layer_weights->griffin_linear_x_biases.data(), /*add0=*/layer_weights->griffin.linear_x_biases.data(),
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x, /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool); /*out1=*/y, pool);
Gelu(y, kModelDim); Gelu(y, kModelDim);
@ -600,13 +600,13 @@ HWY_NOINLINE void GriffinRecurrent(
} }
for (size_t i = 0; i < kModelDim; i += Lanes(df)) { for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
auto xv = hn::Load(df, x + i); auto xv = hn::Load(df, x + i);
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i); auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
auto accum1 = hn::Zero(df); auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even"); static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) { for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() + auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i); (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() + auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i); (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@ -627,10 +627,10 @@ HWY_NOINLINE void GriffinRecurrent(
constexpr size_t kMatrixSize = kHeadDim * kHeadDim; constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
size_t head_offset = head * kHeadDim; size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>( TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
layer_weights->griffin_gate_w, kMatrixSize * head, layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset, kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset, /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim + /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
head_offset, head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim); Sigmoid(gate_x + head_offset, kHeadDim);
@ -638,7 +638,7 @@ HWY_NOINLINE void GriffinRecurrent(
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x) const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); }; HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim, hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin_a.data() + head_offset, fn_mul); layer_weights->griffin.a.data() + head_offset, fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul); fn_mul);
// RNN scan // RNN scan
@ -666,8 +666,8 @@ HWY_NOINLINE void GriffinRecurrent(
// Final linear layer. // Final linear layer.
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim; float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
MatVecAdd<true, kModelDim, kModelDim>( MatVecAdd<true, kModelDim, kModelDim>(
layer_weights->griffin_linear_out_w, 0, x, layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool); layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
} }
template <size_t kBatchSize, typename LayerT, class TConfig> template <size_t kBatchSize, typename LayerT, class TConfig>
@ -1274,17 +1274,17 @@ void ForEachTensor(const Weights<TConfig>* weights,
CALL_FUNC("qkv_ein", qkv_einsum_w); CALL_FUNC("qkv_ein", qkv_einsum_w);
CALL_FUNC("att_ein", attn_vec_einsum_w); CALL_FUNC("att_ein", attn_vec_einsum_w);
} else { } else {
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w); CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases); CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w); CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases); CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w); CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases); CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
CALL_FUNC("gr_conv_w", griffin_conv_w); CALL_FUNC("gr_conv_w", griffin.conv_w);
CALL_FUNC("gr_conv_b", griffin_conv_biases); CALL_FUNC("gr_conv_b", griffin.conv_biases);
CALL_FUNC("gr_gate_w", griffin_gate_w); CALL_FUNC("gr_gate_w", griffin.gate_w);
CALL_FUNC("gr_gate_b", griffin_gate_biases); CALL_FUNC("gr_gate_b", griffin.gate_biases);
CALL_FUNC("gr_a", griffin_a); CALL_FUNC("gr_a", griffin.a);
} }
CALL_FUNC("pre_att_ns", pre_attention_norm_scale); CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
@ -1334,10 +1334,10 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]); layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]); layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
} else { } else {
layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]); layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]); layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]); layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]); layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
} }
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]); layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->linear_w.set_scale(scales[scale_pos++]); layer_weights->linear_w.set_scale(scales[scale_pos++]);