Refactor data structures to reduce memory usage

This commit is contained in:
RangerUFO 2024-04-10 19:35:23 +08:00
parent 54120a5571
commit 809bd0709d
1 changed file with 88 additions and 52 deletions

View File

@ -84,32 +84,45 @@ struct Layer {
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
// 2x for (gelu gating vector, gated vector) // 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr size_t kAOBiaseDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w; template <class T, size_t N>
std::array<float, kQKVEinsumWSize> qkv_einsum_w; using ArrayT = std::array<T, N>;
std::array<float, kGatingEinsumWSize> gating_einsum_w;
std::array<float, kModelDim * kFFHiddenDim> linear_w;
std::array<float, kModelDim> pre_attention_norm_scale;
std::array<float, kModelDim> pre_ffw_norm_scale;
// These fields are only used by Griffin, and do not affect loading of the
// model as it is done per-member.
// TODO(veluca): pull weights that are never used at the same time into a
// union or otherwise reduce the memory usage.
std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
std::array<float, kModelDim> ffw_output_biases;
std::array<float, kModelDim> attention_output_biases;
std::array<float, kModelDim * kModelDim> griffin_linear_y_w; union {
std::array<float, kModelDim> griffin_linear_y_biases; struct {
std::array<float, kModelDim * kModelDim> griffin_linear_x_w; ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
std::array<float, kModelDim> griffin_linear_x_biases; ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
std::array<float, kModelDim * kModelDim> griffin_linear_out_w; ArrayT<float, kAOBiaseDim> attention_output_biases;
std::array<float, kModelDim> griffin_linear_out_biases; };
std::array<float, kModelDim> griffin_conv_biases;
std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w; struct {
std::array<float, kModelDim * 2> griffin_gate_biases; ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_x_w;
std::array<float, kModelDim> griffin_a; ArrayT<float, kGriffinDim> griffin_linear_x_biases;
std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w; ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_y_w;
ArrayT<float, kGriffinDim> griffin_linear_y_biases;
ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_out_w;
ArrayT<float, kGriffinDim> griffin_linear_out_biases;
ArrayT<float, kConv1dWidth * kGriffinDim> griffin_conv_w;
ArrayT<float, kGriffinDim> griffin_conv_biases;
ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
ArrayT<float, kGriffinDim> griffin_a;
};
};
ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
ArrayT<float, kModelDim> pre_attention_norm_scale;
ArrayT<float, kModelDim> pre_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
}; };
float ScaleWeights(float* data, size_t len) { float ScaleWeights(float* data, size_t len) {
@ -287,33 +300,54 @@ struct CompressedLayer {
using TLayer = gcpp::Layer<TConfig>; using TLayer = gcpp::Layer<TConfig>;
using WeightT = typename TConfig::WeightT; using WeightT = typename TConfig::WeightT;
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kHeads = TLayer::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; static constexpr size_t kKVHeads = TLayer::kKVHeads;
static constexpr size_t kModelDim = TLayer::kModelDim;
static constexpr size_t kQKVDim = TLayer::kQKVDim;
static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
static constexpr bool kFFBiases = TLayer::kFFBiases;
static constexpr size_t kAOBiaseDim = TLayer::kAOBiaseDim;
static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
// Compressed Parameters // Compressed Parameters
// We don't yet have an RMSNorm that accepts all WeightT.
CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
CompressedArray<float, kModelDim> ffw_output_biases;
CompressedArray<float, kModelDim> attention_output_biases;
CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w; template <class T, size_t N>
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w; using ArrayT = CompressedArray<T, N>;
CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2> union {
griffin_gate_w; struct {
CompressedArray<float, kModelDim> griffin_a; ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
CompressedArray<float, kModelDim> griffin_linear_y_biases; ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
CompressedArray<float, kModelDim> griffin_linear_x_biases; ArrayT<float, kAOBiaseDim> attention_output_biases;
CompressedArray<float, kModelDim> griffin_linear_out_biases; };
CompressedArray<float, kModelDim> griffin_conv_biases;
CompressedArray<float, kModelDim * 2> griffin_gate_biases; struct {
CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w; ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_x_w;
ArrayT<float, kGriffinDim> griffin_linear_x_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_y_w;
ArrayT<float, kGriffinDim> griffin_linear_y_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_out_w;
ArrayT<float, kGriffinDim> griffin_linear_out_biases;
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> griffin_conv_w;
ArrayT<float, kGriffinDim> griffin_conv_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
ArrayT<float, kGriffinDim> griffin_a;
};
};
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
// We don't yet have an RMSNorm that accepts all WeightT.
ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
}; };
// Array instead of single large allocation for parallel mem init. Split out // Array instead of single large allocation for parallel mem init. Split out
@ -387,10 +421,12 @@ struct Activations {
std::array<float, kBatchSize * TConfig::kVocabSize> logits; std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// Griffin layer internal activations // Griffin layer internal activations
std::array<float, kBatchSize * kModelDim> griffin_x; static constexpr size_t kGriffinDim =
std::array<float, kBatchSize * kModelDim> griffin_y; TConfig::kGriffinLayers > 0 ? kModelDim : 0;
std::array<float, kBatchSize * kModelDim> griffin_gate_x; std::array<float, kBatchSize * kGriffinDim> griffin_x;
std::array<float, kBatchSize * kModelDim> griffin_multiplier; std::array<float, kBatchSize * kGriffinDim> griffin_y;
std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
}; };
// GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we