Refactor data structures to reduce memory usage

RangerUFO 2024-04-10 19:35:23 +08:00
parent 54120a5571
commit 809bd0709d
1 changed file with 88 additions and 52 deletions

@@ -84,32 +84,45 @@ struct Layer {
       (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
   // 2x for (gelu gating vector, gated vector)
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
+  static constexpr bool kFFBiases = TConfig::kFFBiases;
+  static constexpr size_t kAOBiaseDim =
+      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
-  std::array<float, kAttVecEinsumWSize> attn_vec_einsum_w;
-  std::array<float, kQKVEinsumWSize> qkv_einsum_w;
-  std::array<float, kGatingEinsumWSize> gating_einsum_w;
-  std::array<float, kModelDim * kFFHiddenDim> linear_w;
-  std::array<float, kModelDim> pre_attention_norm_scale;
-  std::array<float, kModelDim> pre_ffw_norm_scale;
-  // These fields are only used by Griffin, and do not affect loading of the
-  // model as it is done per-member.
-  // TODO(veluca): pull weights that are never used at the same time into a
-  // union or otherwise reduce the memory usage.
-  std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
-  std::array<float, kModelDim> ffw_output_biases;
-  std::array<float, kModelDim> attention_output_biases;
+  template <class T, size_t N>
+  using ArrayT = std::array<T, N>;
-  std::array<float, kModelDim * kModelDim> griffin_linear_y_w;
-  std::array<float, kModelDim> griffin_linear_y_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_x_w;
-  std::array<float, kModelDim> griffin_linear_x_biases;
-  std::array<float, kModelDim * kModelDim> griffin_linear_out_w;
-  std::array<float, kModelDim> griffin_linear_out_biases;
-  std::array<float, kModelDim> griffin_conv_biases;
-  std::array<float, kModelDim * kModelDim / TConfig::kHeads * 2> griffin_gate_w;
-  std::array<float, kModelDim * 2> griffin_gate_biases;
-  std::array<float, kModelDim> griffin_a;
-  std::array<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  union {
+    struct {
+      ArrayT<float, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<float, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiaseDim> attention_output_biases;
+    };
+    struct {
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_x_w;
+      ArrayT<float, kGriffinDim> griffin_linear_x_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_y_w;
+      ArrayT<float, kGriffinDim> griffin_linear_y_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim> griffin_linear_out_w;
+      ArrayT<float, kGriffinDim> griffin_linear_out_biases;
+      ArrayT<float, kConv1dWidth * kGriffinDim> griffin_conv_w;
+      ArrayT<float, kGriffinDim> griffin_conv_biases;
+      ArrayT<float, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
+      ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
+      ArrayT<float, kGriffinDim> griffin_a;
+    };
+  };
+  ArrayT<float, kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<float, kModelDim * kFFHiddenDim> linear_w;
+  ArrayT<float, kModelDim> pre_attention_norm_scale;
+  ArrayT<float, kModelDim> pre_ffw_norm_scale;
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 float ScaleWeights(float* data, size_t len) {
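The core of this change is the anonymous union of two anonymous structs: any given layer is either an attention layer or a Griffin (recurrent) layer, never both, so the two weight groups can overlap, and the union costs the larger of the two sizes instead of their sum. A minimal standalone sketch of the pattern, with assumed dimensions and names rather than the actual gemma.cpp types (note that anonymous structs inside a union are a compiler extension in C++, though GCC, Clang, and MSVC all accept them):

#include <array>
#include <cstddef>
#include <cstdio>

// Illustrative config values (assumed, not the real Gemma dimensions).
constexpr size_t kModelDim = 1024;
constexpr size_t kGriffinLayers = 0;  // attention-only model
constexpr size_t kGriffinDim = kGriffinLayers > 0 ? kModelDim : 0;

struct Layer {
  union {
    struct {  // active for attention layers
      std::array<float, kModelDim * kModelDim> attn_w;
    };
    struct {  // active for Griffin layers; zero-sized when kGriffinDim == 0
      std::array<float, kGriffinDim * kGriffinDim> griffin_w;
      std::array<float, kGriffinDim> griffin_b;
    };
  };
};

int main() {
  // A union is as large as its largest member, so the weight group a
  // given layer never touches no longer contributes to sizeof(Layer).
  std::printf("sizeof(Layer) = %zu bytes\n", sizeof(Layer));
  return 0;
}

Strictly speaking, writing one union member and then reading another is undefined behavior in C++, but the removed TODO comment explains why this is safe here: loading is done per member, so for any given layer only one arm of the union is ever written or read.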
@@ -287,33 +300,54 @@ struct CompressedLayer {
   using TLayer = gcpp::Layer<TConfig>;
   using WeightT = typename TConfig::WeightT;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kHeads = TLayer::kHeads;
+  static constexpr size_t kKVHeads = TLayer::kKVHeads;
+  static constexpr size_t kModelDim = TLayer::kModelDim;
+  static constexpr size_t kQKVDim = TLayer::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
+  static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
+  static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
+  static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
+  static constexpr bool kFFBiases = TLayer::kFFBiases;
+  static constexpr size_t kAOBiaseDim = TLayer::kAOBiaseDim;
+  static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
 
   // Compressed Parameters
-  // We don't yet have an RMSNorm that accepts all WeightT.
-  CompressedArray<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
-  CompressedArray<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
-  CompressedArray<float, 2 * kFFHiddenDim> ffw_gating_biases;
-  CompressedArray<float, kModelDim> ffw_output_biases;
-  CompressedArray<float, kModelDim> attention_output_biases;
-  CompressedArray<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
-  CompressedArray<WeightT, kModelDim * kFFHiddenDim> linear_w;
-  CompressedArray<WeightT, TLayer::kQKVEinsumWSize> qkv_einsum_w;
-  CompressedArray<WeightT, TLayer::kAttVecEinsumWSize> attn_vec_einsum_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_y_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_x_w;
-  CompressedArray<WeightT, kModelDim * kModelDim> griffin_linear_out_w;
-  CompressedArray<WeightT, kModelDim * kModelDim / TConfig::kHeads * 2>
-      griffin_gate_w;
-  CompressedArray<float, kModelDim> griffin_a;
-  CompressedArray<float, kModelDim> griffin_linear_y_biases;
-  CompressedArray<float, kModelDim> griffin_linear_x_biases;
-  CompressedArray<float, kModelDim> griffin_linear_out_biases;
-  CompressedArray<float, kModelDim> griffin_conv_biases;
-  CompressedArray<float, kModelDim * 2> griffin_gate_biases;
-  CompressedArray<float, TConfig::kConv1dWidth * kModelDim> griffin_conv_w;
+  template <class T, size_t N>
+  using ArrayT = CompressedArray<T, N>;
+  union {
+    struct {
+      ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
+      ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
+      ArrayT<float, kAOBiaseDim> attention_output_biases;
+    };
+    struct {
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_x_w;
+      ArrayT<float, kGriffinDim> griffin_linear_x_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_y_w;
+      ArrayT<float, kGriffinDim> griffin_linear_y_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim> griffin_linear_out_w;
+      ArrayT<float, kGriffinDim> griffin_linear_out_biases;
+      ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> griffin_conv_w;
+      ArrayT<float, kGriffinDim> griffin_conv_biases;
+      ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> griffin_gate_w;
+      ArrayT<float, kGriffinDim * 2> griffin_gate_biases;
+      ArrayT<float, kGriffinDim> griffin_a;
+    };
+  };
+  ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
+  ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
+  // We don't yet have an RMSNorm that accepts all WeightT.
+  ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
+  ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
+  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
 };
 
 // Array instead of single large allocation for parallel mem init. Split out
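CompressedLayer applies the same two tricks: the union overlaps the attention and Griffin weight groups, and conditional sizes such as kFFBiases ? 2 * kFFHiddenDim : 0 collapse unused bias arrays to zero elements. A rough before/after comparison under assumed dimensions (hypothetical values, with plain std::array standing in for CompressedArray):

#include <array>
#include <cstddef>
#include <cstdio>

// Assumed, illustrative dimensions; not the real Gemma configs.
constexpr size_t kModelDim = 2048;
constexpr size_t kFFHiddenDim = 16384;
constexpr size_t kConv1dWidth = 4;
constexpr bool kFFBiases = false;  // attention-only models skip FF biases
constexpr size_t kGriffinDim = 0;  // ...and have no Griffin layers

struct Before {  // every field unconditionally sized by kModelDim
  std::array<float, kModelDim * kModelDim> attn_vec_einsum_w;
  std::array<float, 3 * kModelDim * kModelDim> griffin_linear_w;  // x, y, out
  std::array<float, kConv1dWidth * kModelDim> griffin_conv_w;
  std::array<float, 2 * kFFHiddenDim> ffw_gating_biases;
};

struct After {  // union + conditional dimensions
  union {
    struct { std::array<float, kModelDim * kModelDim> attn_vec_einsum_w; };
    struct {
      std::array<float, 3 * kGriffinDim * kGriffinDim> griffin_linear_w;
      std::array<float, kConv1dWidth * kGriffinDim> griffin_conv_w;
    };
  };
  std::array<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
};

int main() {
  std::printf("before: %zu KiB per layer\n", sizeof(Before) >> 10);
  std::printf("after:  %zu KiB per layer\n", sizeof(After) >> 10);
  return 0;
}

Under these assumptions the Griffin weights go from dominating the layer to costing nothing; on a config that does use Griffin (kGriffinDim == kModelDim), the union still saves the size of the smaller arm on every layer.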
@@ -387,10 +421,12 @@ struct Activations {
   std::array<float, kBatchSize * TConfig::kVocabSize> logits;
 
   // Griffin layer internal activations
-  std::array<float, kBatchSize * kModelDim> griffin_x;
-  std::array<float, kBatchSize * kModelDim> griffin_y;
-  std::array<float, kBatchSize * kModelDim> griffin_gate_x;
-  std::array<float, kBatchSize * kModelDim> griffin_multiplier;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
+  std::array<float, kBatchSize * kGriffinDim> griffin_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_y;
+  std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
+  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
 };
 
 // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we
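For the activations the saving comes purely from the conditional dimension: with no Griffin layers, kGriffinDim is 0 and each scratch buffer degenerates to std::array<float, 0>, a complete but empty type. A small check of what that costs, under assumed batch and model sizes (hypothetical values, not the real configs):

#include <array>
#include <cstddef>
#include <cstdio>

constexpr size_t kBatchSize = 4;
constexpr size_t kModelDim = 2048;    // assumed
constexpr size_t kGriffinLayers = 0;  // attention-only model
constexpr size_t kGriffinDim = kGriffinLayers > 0 ? kModelDim : 0;

struct Activations {
  std::array<float, kBatchSize * kGriffinDim> griffin_x;
  std::array<float, kBatchSize * kGriffinDim> griffin_y;
  std::array<float, kBatchSize * kGriffinDim> griffin_gate_x;
  std::array<float, kBatchSize * kGriffinDim> griffin_multiplier;
};

int main() {
  // std::array<float, 0> is valid and empty; its size is
  // implementation-defined but tiny (1 byte on libstdc++/libc++),
  // versus kBatchSize * kModelDim * sizeof(float) = 32 KiB per
  // buffer before this change.
  std::printf("sizeof(std::array<float, 0>) = %zu\n",
              sizeof(std::array<float, 0>));
  std::printf("sizeof(Activations)          = %zu\n", sizeof(Activations));
  return 0;
}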