mirror of https://github.com/google/gemma.cpp.git
Merge pull request #142 from ufownl:refactor/data_structures
PiperOrigin-RevId: 623503486
This commit is contained in:
commit
342e998cb6
224
gemma/gemma.cc
224
gemma/gemma.cc
|
|
@ -84,32 +84,45 @@ struct Layer {
|
|||
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
|
||||
// 2x for (gelu gating vector, gated vector)
|
||||
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||
static constexpr bool kFFBiases = TConfig::kFFBiases;
|
||||
static constexpr size_t kAOBiasDim =
|
||||
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
|
||||
static constexpr size_t kGriffinDim =
|
||||
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
||||
|
||||
std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||
std::array<float, kQKVEinsumWSize> qkv_einsum_w;
|
||||
std::array<float, kGatingEinsumWSize> gating_einsum_w;
|
||||
std::array<float, kModelDim * kFFHiddenDim> linear_w;
|
||||
std::array<float, kModelDim> pre_attention_norm_scale;
|
||||
std::array<float, kModelDim> pre_ffw_norm_scale;
|
||||
// These fields are only used by Griffin, and do not affect loading of the
|
||||
// model as it is done per-member.
|
||||
// TODO(veluca): pull weights that are never used at the same time into a
|
||||
// union or otherwise reduce the memory usage.
|
||||
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
|
||||
std::array<float, kModelDim> ffw_output_biases;
|
||||
std::array<float, kModelDim> attention_output_biases;
|
||||
template <class T, size_t N>
|
||||
using ArrayT = std::array<T, N>;
|
||||
|
||||
std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
|
||||
std::array<float, kModelDim> griffin_linear_y_biases;
|
||||
std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
|
||||
std::array<float, kModelDim> griffin_linear_x_biases;
|
||||
std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
|
||||
std::array<float, kModelDim> griffin_linear_out_biases;
|
||||
std::array<float, kModelDim> griffin_conv_biases;
|
||||
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
|
||||
std::array<float, kModelDim * 2> griffin_gate_biases;
|
||||
std::array<float, kModelDim> griffin_a;
|
||||
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
|
||||
union {
|
||||
struct {
|
||||
ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||
ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
|
||||
ArrayT<float, kAOBiasDim> attention_output_biases;
|
||||
};
|
||||
|
||||
struct {
|
||||
ArrayT<float, kGriffinDim * kGriffinDim> linear_x_w;
|
||||
ArrayT<float, kGriffinDim> linear_x_biases;
|
||||
ArrayT<float, kGriffinDim * kGriffinDim> linear_y_w;
|
||||
ArrayT<float, kGriffinDim> linear_y_biases;
|
||||
ArrayT<float, kGriffinDim * kGriffinDim> linear_out_w;
|
||||
ArrayT<float, kGriffinDim> linear_out_biases;
|
||||
ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
|
||||
ArrayT<float, kGriffinDim> conv_biases;
|
||||
ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||
ArrayT<float, kGriffinDim> a;
|
||||
} griffin;
|
||||
};
|
||||
|
||||
ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
|
||||
ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
|
||||
ArrayT<float, kModelDim> pre_attention_norm_scale;
|
||||
ArrayT<float, kModelDim> pre_ffw_norm_scale;
|
||||
|
||||
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||
};
|
||||
|
||||
float ScaleWeights(float* data, size_t len) {
|
||||
|
|
@ -233,21 +246,21 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadWeights(
|
|||
SCALE_WEIGHTS(attn_vec_einsum_w);
|
||||
SCALE_WEIGHTS(qkv_einsum_w);
|
||||
} else {
|
||||
READ_WEIGHTS(griffin_linear_x_w);
|
||||
READ_WEIGHTS(griffin_linear_x_biases);
|
||||
READ_WEIGHTS(griffin_linear_y_w);
|
||||
READ_WEIGHTS(griffin_linear_y_biases);
|
||||
READ_WEIGHTS(griffin_linear_out_w);
|
||||
READ_WEIGHTS(griffin_linear_out_biases);
|
||||
READ_WEIGHTS(griffin_conv_w);
|
||||
READ_WEIGHTS(griffin_conv_biases);
|
||||
READ_WEIGHTS(griffin_gate_w);
|
||||
READ_WEIGHTS(griffin_gate_biases);
|
||||
READ_WEIGHTS(griffin_a);
|
||||
SCALE_WEIGHTS(griffin_linear_x_w);
|
||||
SCALE_WEIGHTS(griffin_linear_y_w);
|
||||
SCALE_WEIGHTS(griffin_linear_out_w);
|
||||
SCALE_WEIGHTS(griffin_gate_w);
|
||||
READ_WEIGHTS(griffin.linear_x_w);
|
||||
READ_WEIGHTS(griffin.linear_x_biases);
|
||||
READ_WEIGHTS(griffin.linear_y_w);
|
||||
READ_WEIGHTS(griffin.linear_y_biases);
|
||||
READ_WEIGHTS(griffin.linear_out_w);
|
||||
READ_WEIGHTS(griffin.linear_out_biases);
|
||||
READ_WEIGHTS(griffin.conv_w);
|
||||
READ_WEIGHTS(griffin.conv_biases);
|
||||
READ_WEIGHTS(griffin.gate_w);
|
||||
READ_WEIGHTS(griffin.gate_biases);
|
||||
READ_WEIGHTS(griffin.a);
|
||||
SCALE_WEIGHTS(griffin.linear_x_w);
|
||||
SCALE_WEIGHTS(griffin.linear_y_w);
|
||||
SCALE_WEIGHTS(griffin.linear_out_w);
|
||||
SCALE_WEIGHTS(griffin.gate_w);
|
||||
}
|
||||
READ_WEIGHTS(gating_einsum_w);
|
||||
READ_WEIGHTS(linear_w);
|
||||
|
|
@ -287,33 +300,54 @@ struct CompressedLayer {
|
|||
using TLayer = gcpp::Layer<TConfig>;
|
||||
using WeightT = typename TConfig::WeightT;
|
||||
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
static constexpr size_t kHeads = TLayer::kHeads;
|
||||
static constexpr size_t kKVHeads = TLayer::kKVHeads;
|
||||
static constexpr size_t kModelDim = TLayer::kModelDim;
|
||||
static constexpr size_t kQKVDim = TLayer::kQKVDim;
|
||||
static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
|
||||
static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
|
||||
static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
|
||||
static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
|
||||
static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
|
||||
static constexpr bool kFFBiases = TLayer::kFFBiases;
|
||||
static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
|
||||
static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
|
||||
|
||||
// Compressed Parameters
|
||||
// We don't yet have an RMSNorm that accepts all WeightT.
|
||||
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
|
||||
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
|
||||
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
|
||||
CompressedArray<float, kModelDim> ffw_output_biases;
|
||||
CompressedArray<float, kModelDim> attention_output_biases;
|
||||
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
||||
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
|
||||
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
|
||||
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||
|
||||
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
|
||||
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
|
||||
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
|
||||
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
|
||||
griffin_gate_w;
|
||||
CompressedArray<float, kModelDim> griffin_a;
|
||||
CompressedArray<float, kModelDim> griffin_linear_y_biases;
|
||||
CompressedArray<float, kModelDim> griffin_linear_x_biases;
|
||||
CompressedArray<float, kModelDim> griffin_linear_out_biases;
|
||||
CompressedArray<float, kModelDim> griffin_conv_biases;
|
||||
CompressedArray<float, kModelDim * 2> griffin_gate_biases;
|
||||
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
|
||||
template <class T, size_t N>
|
||||
using ArrayT = CompressedArray<T, N>;
|
||||
|
||||
union {
|
||||
struct {
|
||||
ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||
ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
|
||||
ArrayT<float, kAOBiasDim> attention_output_biases;
|
||||
};
|
||||
|
||||
struct {
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
|
||||
ArrayT<float, kGriffinDim> linear_x_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
|
||||
ArrayT<float, kGriffinDim> linear_y_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
|
||||
ArrayT<float, kGriffinDim> linear_out_biases;
|
||||
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
|
||||
ArrayT<float, kGriffinDim> conv_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||
ArrayT<float, kGriffinDim> a;
|
||||
} griffin;
|
||||
};
|
||||
|
||||
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
||||
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
|
||||
// We don't yet have an RMSNorm that accepts all WeightT.
|
||||
ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
|
||||
ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
|
||||
|
||||
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||
};
|
||||
|
||||
// Array instead of single large allocation for parallel mem init. Split out
|
||||
|
|
@ -387,10 +421,12 @@ struct Activations {
|
|||
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
|
||||
|
||||
// Griffin layer internal activations
|
||||
std::array<float, kBatchSize * kModelDim> griffin_x;
|
||||
std::array<float, kBatchSize * kModelDim> griffin_y;
|
||||
std::array<float, kBatchSize * kModelDim> griffin_gate_x;
|
||||
std::array<float, kBatchSize * kModelDim> griffin_multiplier;
|
||||
static constexpr size_t kGriffinDim =
|
||||
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_x;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_y;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
|
||||
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
|
||||
};
|
||||
|
||||
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
|
||||
|
|
@ -541,10 +577,10 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset;
|
||||
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
|
||||
TwoMatVecAdd<true, kModelDim, kModelDim>(
|
||||
layer_weights->griffin_linear_x_w, layer_weights->griffin_linear_y_w, 0,
|
||||
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
/*add0=*/layer_weights->griffin_linear_x_biases.data(),
|
||||
/*add1=*/layer_weights->griffin_linear_y_biases.data(), /*out0=*/x,
|
||||
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
|
||||
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
|
||||
/*out1=*/y, pool);
|
||||
Gelu(y, kModelDim);
|
||||
|
||||
|
|
@ -564,13 +600,13 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
}
|
||||
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
|
||||
auto xv = hn::Load(df, x + i);
|
||||
auto accum0 = hn::Load(df, layer_weights->griffin_conv_biases.data() + i);
|
||||
auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
|
||||
auto accum1 = hn::Zero(df);
|
||||
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
|
||||
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
|
||||
auto wv0 = hn::Load(df, layer_weights->griffin_conv_w.data() +
|
||||
auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
|
||||
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
|
||||
auto wv1 = hn::Load(df, layer_weights->griffin_conv_w.data() +
|
||||
auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
|
||||
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
|
||||
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
||||
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
||||
|
|
@ -591,10 +627,10 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
constexpr size_t kMatrixSize = kHeadDim * kHeadDim;
|
||||
size_t head_offset = head * kHeadDim;
|
||||
TwoOfsMatVecAddLoop<true, kHeadDim, kHeadDim>(
|
||||
layer_weights->griffin_gate_w, kMatrixSize * head,
|
||||
layer_weights->griffin.gate_w, kMatrixSize * head,
|
||||
kMatrixSize * (kHeads + head), x + head_offset,
|
||||
/*add0=*/layer_weights->griffin_gate_biases.data() + head_offset,
|
||||
/*add1=*/layer_weights->griffin_gate_biases.data() + kModelDim +
|
||||
/*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
|
||||
/*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
|
||||
head_offset,
|
||||
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
||||
Sigmoid(gate_x + head_offset, kHeadDim);
|
||||
|
|
@ -602,7 +638,7 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
||||
HWY_ATTR { return hn::Mul(x, gate_x); };
|
||||
hn::Transform1(D(), a + head_offset, kHeadDim,
|
||||
layer_weights->griffin_a.data() + head_offset, fn_mul);
|
||||
layer_weights->griffin.a.data() + head_offset, fn_mul);
|
||||
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||
fn_mul);
|
||||
// RNN scan
|
||||
|
|
@ -630,8 +666,8 @@ HWY_NOINLINE void GriffinRecurrent(
|
|||
// Final linear layer.
|
||||
float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim;
|
||||
MatVecAdd<true, kModelDim, kModelDim>(
|
||||
layer_weights->griffin_linear_out_w, 0, x,
|
||||
layer_weights->griffin_linear_out_biases.data(), out_ptr, pool);
|
||||
layer_weights->griffin.linear_out_w, 0, x,
|
||||
layer_weights->griffin.linear_out_biases.data(), out_ptr, pool);
|
||||
}
|
||||
|
||||
template <size_t kBatchSize, typename LayerT, class TConfig>
|
||||
|
|
@ -1238,17 +1274,17 @@ void ForEachTensor(const Weights<TConfig>* weights,
|
|||
CALL_FUNC("qkv_ein", qkv_einsum_w);
|
||||
CALL_FUNC("att_ein", attn_vec_einsum_w);
|
||||
} else {
|
||||
CALL_FUNC("gr_lin_x_w", griffin_linear_x_w);
|
||||
CALL_FUNC("gr_lin_x_b", griffin_linear_x_biases);
|
||||
CALL_FUNC("gr_lin_y_w", griffin_linear_y_w);
|
||||
CALL_FUNC("gr_lin_y_b", griffin_linear_y_biases);
|
||||
CALL_FUNC("gr_lin_out_w", griffin_linear_out_w);
|
||||
CALL_FUNC("gr_lin_out_b", griffin_linear_out_biases);
|
||||
CALL_FUNC("gr_conv_w", griffin_conv_w);
|
||||
CALL_FUNC("gr_conv_b", griffin_conv_biases);
|
||||
CALL_FUNC("gr_gate_w", griffin_gate_w);
|
||||
CALL_FUNC("gr_gate_b", griffin_gate_biases);
|
||||
CALL_FUNC("gr_a", griffin_a);
|
||||
CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
|
||||
CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
|
||||
CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
|
||||
CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
|
||||
CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
|
||||
CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
|
||||
CALL_FUNC("gr_conv_w", griffin.conv_w);
|
||||
CALL_FUNC("gr_conv_b", griffin.conv_biases);
|
||||
CALL_FUNC("gr_gate_w", griffin.gate_w);
|
||||
CALL_FUNC("gr_gate_b", griffin.gate_biases);
|
||||
CALL_FUNC("gr_a", griffin.a);
|
||||
}
|
||||
CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
|
||||
|
||||
|
|
@ -1298,10 +1334,10 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> LoadCompressedWeights(
|
|||
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
|
||||
} else {
|
||||
layer_weights->griffin_linear_x_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin_linear_y_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin_linear_out_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin_gate_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
|
||||
}
|
||||
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->linear_w.set_scale(scales[scale_pos++]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue