From a3a75b77f9ab15761b8ec9e46cbdb8b1d0287928 Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Mon, 10 Jun 2024 14:34:24 +0000 Subject: [PATCH 1/2] Use CompressedWeights> in backpropagation. kWeightsAreCompressed are removed and LoadRawWeights is moved to compress_weights.cc --- backprop/backward-inl.h | 61 +++++++------- backprop/backward.cc | 5 +- backprop/backward_test.cc | 9 +- backprop/forward.cc | 7 +- backprop/optimize_test.cc | 26 +++--- backprop/optimizer.cc | 39 +++++---- compression/compress.h | 3 + gemma/compress_weights.cc | 152 ++++++++++++++++++++++++++++++++++ gemma/gemma.cc | 13 +-- gemma/weights.cc | 170 ++------------------------------------ gemma/weights.h | 100 ++++++++++++---------- 11 files changed, 300 insertions(+), 285 deletions(-) diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index fb7a3a1..96f1f08 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -28,7 +28,6 @@ #include "backprop/prompt.h" #include "gemma/activations.h" #include "gemma/common.h" -#include "gemma/weights.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -51,12 +50,12 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; template -void MatMulVJP(const std::array& weights, - const float* HWY_RESTRICT x, // num_tokens * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows +void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols, + const float* HWY_RESTRICT x, // num_tokens * kCols + const float* HWY_RESTRICT v, // num_tokens * kRows size_t num_tokens, - std::array& grad_w, - float* HWY_RESTRICT grad_x, // num_tokens * kCols + float* HWY_RESTRICT grad_w, // kRows * kCols, + float* HWY_RESTRICT grad_x, // num_tokens * kCols hwy::ThreadPool& pool) { hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -72,12 +71,12 @@ void MatMulVJP(const std::array& weights, template void MultiHeadMatMulVJP( - const std::array& weights, - const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols - const float* HWY_RESTRICT v, // num_tokens * kRows + const float* HWY_RESTRICT weights, // kHeads * kRows * kCols + const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols + const float* HWY_RESTRICT v, // num_tokens * kRows size_t num_tokens, - std::array& grad_w, - float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols + float* HWY_RESTRICT grad_w, // kHeads * kRows * kCols + float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols hwy::ThreadPool& pool) { hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -166,12 +165,12 @@ static HWY_NOINLINE void InputEmbeddingVJP( } } -template -void LayerVJP(const Layer& weights, +template typename LayerT> +void LayerVJP(const LayerT& weights, const ForwardLayer& forward, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, - Layer& grad, + LayerT& grad, ForwardLayer& backward, hwy::ThreadPool& pool) { static constexpr size_t kModelDim = TConfig::kModelDim; @@ -184,8 +183,8 @@ void LayerVJP(const Layer& weights, HWY_ASSERT(num_tokens <= kSeqLen); MatMulVJP( - weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad, - num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(), + weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad, + num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -210,9 +209,9 @@ void LayerVJP(const Layer& 
weights, } MatMulVJP( - weights.gating_einsum_w, + weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(), - num_tokens, grad.gating_einsum_w, + num_tokens, grad.gating_einsum_w.data(), backward.bf_pre_ffw_rms_out.data(), pool); RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), @@ -230,9 +229,9 @@ void LayerVJP(const Layer& weights, num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0])); MultiHeadMatMulVJP( - weights.attn_vec_einsum_w, forward.att_out.data(), + weights.attn_vec_einsum_w.data(), forward.att_out.data(), backward.attention_out.data(), num_tokens, - grad.attn_vec_einsum_w, backward.att_out.data(), pool); + grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool); for (size_t head = 0; head < kHeads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -293,9 +292,9 @@ void LayerVJP(const Layer& weights, } MatMulVJP( - weights.qkv_einsum_w, forward.pre_att_rms_out.data(), - backward.qkv.data(), num_tokens, - grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool); + weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), + backward.qkv.data(), num_tokens, + grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool); RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(), backward.pre_att_rms_out.data(), @@ -345,11 +344,12 @@ static HWY_NOINLINE void CrossEntropyLossGrad( } } -template +template typename WeightsT, + template typename LayerT> void CrossEntropyLossBackwardPass(const Prompt& prompt, - const Weights& weights, + const WeightsT& weights, const ForwardPass& forward, - Weights& grad, + WeightsT& grad, ForwardPass& backward, hwy::ThreadPool& pool) { static constexpr size_t kVocabSize = TConfig::kVocabSize; @@ -379,9 +379,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, } MatMulVJP( - weights.embedder_input_embedding, forward.final_norm_output.data(), + weights.embedder_input_embedding.data(), forward.final_norm_output.data(), backward.logits.data(), num_tokens, - grad.embedder_input_embedding, backward.final_norm_output.data(), + grad.embedder_input_embedding.data(), backward.final_norm_output.data(), pool); RMSNormVJP(weights.final_norm_scale.data(), @@ -398,8 +398,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, float* next_layer_grad = layer + 1 < kLayers ? 
backward.layers[layer + 1].input.data() : backward.final_layer_output.data(); - LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool); + LayerVJP( + *weights.GetLayer(layer), forward.layers[layer], next_layer_grad, + num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool); } InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, diff --git a/backprop/backward.cc b/backprop/backward.cc index bc1a630..87ede98 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -42,13 +42,14 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, ByteStorageT& grad_u8, ByteStorageT& backward_u8, hwy::ThreadPool& pool) { - using TWeights = WeightsF; + using TWeights = CompressedWeights; const auto& weights = *reinterpret_cast(weights_u8.get()); auto& grad = *reinterpret_cast(grad_u8.get()); using TAct = ForwardPass; const auto& forward = *reinterpret_cast(forward_u8.get()); auto& backward = *reinterpret_cast(backward_u8.get()); - CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool); + CrossEntropyLossBackwardPass( + prompt, weights, forward, grad, backward, pool); } void CrossEntropyLossBackwardPassT(Model model, diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 9b595a1..4a3e4cc 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -84,8 +84,8 @@ void TestMatMulVJP() { }; hwy::ZeroBytes(&grad, sizeof(grad)); - MatMulVJP(weights, x.data(), dy.data(), kTokens, - grad, dx.data(), pool); + MatMulVJP(weights.data(), x.data(), dy.data(), kTokens, + grad.data(), dx.data(), pool); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); @@ -130,7 +130,8 @@ void TestMultiHeadMatMulVJP() { hwy::ZeroBytes(&grad, sizeof(grad)); MultiHeadMatMulVJP( - weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); + weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(), + pool); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); @@ -235,7 +236,7 @@ void TestEndToEnd() { EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); grad.clear(); - CrossEntropyLossBackwardPass( + CrossEntropyLossBackwardPass( prompt, weights.get(), forward1.get(), grad.get(), backward.get(), pool); diff --git a/backprop/forward.cc b/backprop/forward.cc index 0880ee2..1c8670e 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -41,11 +41,12 @@ float CrossEntropyLossForwardPass(const Prompt& prompt, ByteStorageT& forward_u8, hwy::ThreadPool& pool) { const auto& weights = - *reinterpret_cast*>(weights_u8.get()); + *reinterpret_cast*>(weights_u8.get()); auto& forward = *reinterpret_cast*>(forward_u8.get()); - return CrossEntropyLossForwardPass( - prompt.tokens, prompt.context_size, weights, forward, pool); + return + CrossEntropyLossForwardPass( + prompt.tokens, prompt.context_size, weights, forward, pool); } float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 2a79049..ad82353 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -34,19 +34,17 @@ namespace gcpp { TEST(OptimizeTest, GradientDescent) { - if (kWeightsAreCompressed) return; - hwy::ThreadPool pool(0); std::mt19937 gen(42); Model model_type = Model::GEMMA_TINY; Type weight_type = Type::kF32; - ByteStorageT grad = - CallForModelAndWeight(model_type, 
weight_type, pool); - ByteStorageT grad_m = - CallForModelAndWeight(model_type, weight_type, pool); - ByteStorageT grad_v = - CallForModelAndWeight(model_type, weight_type, pool); + ByteStorageT grad = CallForModelAndWeight( + model_type, weight_type, pool); + ByteStorageT grad_m = CallForModelAndWeight( + model_type, weight_type, pool); + ByteStorageT grad_v = CallForModelAndWeight( + model_type, weight_type, pool); ByteStorageT forward = CallForModelAndWeight(model_type, weight_type); ByteStorageT backward = @@ -88,10 +86,10 @@ TEST(OptimizeTest, GradientDescent) { }; RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen); - CallForModelAndWeight(model_type, weight_type, grad_m, - pool); - CallForModelAndWeight(model_type, weight_type, grad_v, - pool); + CallForModelAndWeight( + model_type, weight_type, grad_m, pool); + CallForModelAndWeight( + model_type, weight_type, grad_v, pool); printf("Initial weights:\n"); LogWeightStats(model_type, weight_type, gemma.Weights()); @@ -109,8 +107,8 @@ TEST(OptimizeTest, GradientDescent) { size_t num_ok; for (; steps < 1000000; ++steps) { std::mt19937 sgen(42); - CallForModelAndWeight(model_type, weight_type, grad, - pool); + CallForModelAndWeight( + model_type, weight_type, grad, pool); float total_loss = 0.0f; num_ok = 0; for (size_t i = 0; i < kBatchSize; ++i) { diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 4302f9c..93d3164 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -32,10 +32,11 @@ class WeightInitializer { WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} template - void operator()(const char* name, std::array& tensor) { + void operator()(const char* name, CompressedArray& tensor) { for (size_t i = 0; i < N; ++i) { tensor[i] = dist_(gen_); } + tensor.set_scale(1.0f); } private: std::normal_distribution dist_; @@ -46,11 +47,12 @@ template struct RandInitWeightsT { void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool, std::mt19937& gen) const { - auto& weights = *reinterpret_cast*>(weights_u8.get()); + auto& weights = + *reinterpret_cast*>(weights_u8.get()); // TODO(szabadka) Use the same weight initialization method as in the python // version. 
WeightInitializer init(gen); - ForEachTensor1(init, weights); + ForEachTensor1(init, weights); } }; @@ -63,10 +65,11 @@ class AdamUpdater { norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {} template - void operator()(const char* name, const std::array& grad, - std::array& weights, - std::array& grad_m, - std::array& grad_v) { + void operator()(const char* name, + const CompressedArray& grad, + CompressedArray& weights, + CompressedArray& grad_m, + CompressedArray& grad_v) { for (size_t i = 0; i < kCapacity; ++i) { grad_m[i] *= beta1_; grad_m[i] += cbeta1_ * grad[i]; @@ -95,13 +98,13 @@ struct AdamUpdateT { float beta2, float epsilon, size_t t, const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { - const auto& grad = - *reinterpret_cast*>(grad_u8.get()); - auto& weights = *reinterpret_cast*>(weights_u8.get()); - auto& grad_m = *reinterpret_cast*>(grad_m_u8.get()); - auto& grad_v = *reinterpret_cast*>(grad_v_u8.get()); + using TWeights = CompressedWeights; + const auto& grad = *reinterpret_cast(grad_u8.get()); + auto& weights = *reinterpret_cast(weights_u8.get()); + auto& grad_m = *reinterpret_cast(grad_m_u8.get()); + auto& grad_v = *reinterpret_cast(grad_v_u8.get()); AdamUpdater updater(alpha, beta1, beta2, epsilon, t); - ForEachTensor4(updater, grad, weights, grad_m, grad_v); + ForEachTensor4(updater, grad, weights, grad_m, grad_v); } }; @@ -110,17 +113,17 @@ struct AdamUpdateT { void RandInitWeights(Model model_type, Type weight_type, const ByteStorageT& weights, hwy::ThreadPool& pool, std::mt19937& gen) { - CallForModelAndWeight(model_type, weight_type, weights, - pool, gen); + HWY_ASSERT(weight_type == Type::kF32); + CallForModel(model_type, weights, pool, gen); } void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, float alpha, float beta1, float beta2, float epsilon, size_t t, const ByteStorageT& weights, const ByteStorageT& grad_m, const ByteStorageT& grad_v, hwy::ThreadPool& pool) { - CallForModelAndWeight(model_type, weight_type, grad, alpha, - beta1, beta2, epsilon, t, weights, grad_m, - grad_v, pool); + HWY_ASSERT(weight_type == Type::kF32); + CallForModel(model_type, grad, alpha, beta1, beta2, + epsilon, t, weights, grad_m, grad_v, pool); } } // namespace gcpp diff --git a/compression/compress.h b/compression/compress.h index edb7fdb..344cabc 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -79,6 +79,9 @@ class CompressedArray { MatT* data() { return data_.data(); } const MatT* data() const { return data_.data(); } + MatT& operator[](size_t pos) { return data_[pos]; } + const MatT& operator[](size_t pos) const { return data_[pos]; } + float scale() const { return scale_[0]; } void set_scale(float scale) { scale_[0] = scale; } diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc index 5bfa5e0..d92b2b3 100644 --- a/gemma/compress_weights.cc +++ b/gemma/compress_weights.cc @@ -41,10 +41,162 @@ #include "gemma/weights.h" #include "util/args.h" #include "hwy/base.h" +#include "hwy/profiler.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { +// Setting this to true disables fread() calls that read the model file. 
+constexpr bool kDryRunFread = false;
+
+namespace {
+float ScaleWeights(float* data, size_t len) {
+  float maxabs = 0.0;
+  for (size_t i = 0; i < len; ++i) {
+    maxabs = std::max(maxabs, std::abs(data[i]));
+  }
+  const float kMaxRange = 1.875f;
+  if (maxabs <= kMaxRange) {
+    return 1.0f;
+  }
+  const float scale = maxabs / kMaxRange;
+  const float inv_scale = 1.0f / scale;
+  for (size_t i = 0; i < len; ++i) {
+    data[i] *= inv_scale;
+  }
+  return scale;
+}
+
+#define READ_WEIGHTS(name) \
+  do { \
+    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
+  } while (0)
+
+#define SCALE_WEIGHTS(name) \
+  do { \
+    if (ok && !kDryRunFread && scale_for_compression) { \
+      weights->scales[scale_pos++] = \
+          ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
+    } \
+  } while (0)
+
+template <typename TConfig>
+struct LoadRawWeightsT {
+  ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
+                          bool scale_for_compression) const {
+    PROFILER_ZONE("Startup.LoadWeights");
+    if (!checkpoint.Exists()) {
+      HWY_ABORT("The model weights file '%s' does not exist.",
+                checkpoint.path.c_str());
+    }
+
+    ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
+    auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+
+    size_t scale_pos = 0;
+    FILE* fptr;
+    if constexpr (kDryRunFread) {
+      fprintf(stderr, "Dry-Run, not reading model-file.\n");
+    } else {
+      fptr = fopen(checkpoint.path.c_str(), "rb");
+      if (fptr == nullptr) {
+        HWY_ABORT("Failed to open model file %s - does it exist?",
+                  checkpoint.path.c_str());
+      }
+    }
+    bool ok = true;
+    uint64_t total_size = 0;
+    auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
+      if (layer == -1) {
+        fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
+      } else {
+        fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
+                size, name);
+      }
+      if constexpr (!kDryRunFread) {
+        ok &= 1 == fread(var, size, 1, fptr);
+        total_size += size;
+      }
+    };
+    do_fread(&(weights->embedder_input_embedding), -1,
+             "embedder_input_embedding",
+             sizeof(weights->embedder_input_embedding));
+    do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
+             sizeof(weights->final_norm_scale));
+    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+      auto type = TConfig::kLayerConfig[layer];
+      LayerF<TConfig>* layer_view = weights->GetLayer(layer);
+
+      // Make sure we don't have uninitialized memory.
+ hwy::ZeroBytes(layer_view, sizeof(*layer_view)); + if (type == LayerAttentionType::kGemma) { + READ_WEIGHTS(attn_vec_einsum_w); + READ_WEIGHTS(qkv_einsum_w); + SCALE_WEIGHTS(attn_vec_einsum_w); + SCALE_WEIGHTS(qkv_einsum_w); + } else { + READ_WEIGHTS(griffin.linear_x_w); + READ_WEIGHTS(griffin.linear_x_biases); + READ_WEIGHTS(griffin.linear_y_w); + READ_WEIGHTS(griffin.linear_y_biases); + READ_WEIGHTS(griffin.linear_out_w); + READ_WEIGHTS(griffin.linear_out_biases); + READ_WEIGHTS(griffin.conv_w); + READ_WEIGHTS(griffin.conv_biases); + READ_WEIGHTS(griffin.gate_w); + READ_WEIGHTS(griffin.gate_biases); + READ_WEIGHTS(griffin.a); + SCALE_WEIGHTS(griffin.linear_x_w); + SCALE_WEIGHTS(griffin.linear_y_w); + SCALE_WEIGHTS(griffin.linear_out_w); + SCALE_WEIGHTS(griffin.gate_w); + } + READ_WEIGHTS(gating_einsum_w); + READ_WEIGHTS(linear_w); + SCALE_WEIGHTS(gating_einsum_w); + SCALE_WEIGHTS(linear_w); + READ_WEIGHTS(pre_attention_norm_scale); + READ_WEIGHTS(pre_ffw_norm_scale); + if (TConfig::kPostNormScale) { + READ_WEIGHTS(post_attention_norm_scale); + READ_WEIGHTS(post_ffw_norm_scale); + } + if (TConfig::kFFBiases) { + READ_WEIGHTS(ffw_gating_biases); + READ_WEIGHTS(ffw_output_biases); + } + if (TConfig::kSoftmaxAttnOutputBiases && + type == LayerAttentionType::kGemma) { + READ_WEIGHTS(attention_output_biases); + } + } + if (!ok) { + HWY_ABORT( + "Failed to read from %s - might be a directory, or too small? " + "expected size: %d kB", + checkpoint.path.c_str(), static_cast(total_size >> 10)); + } + if (!kDryRunFread) { + HWY_ASSERT(0 == fclose(fptr)); + if (scale_for_compression) { + HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); + } + } + return weights_u8; + } +}; + +#undef READ_WEIGHTS +#undef SCALE_WEIGHTS +} // namespace + +ByteStorageT LoadRawWeights(const Path& weights, Model model_type, + Type weight_type, hwy::ThreadPool& pool, + bool scale_for_compression) { + return CallForModelAndWeight( + model_type, weight_type, weights, pool, scale_for_compression); +} + struct Args : public ArgsBase { static constexpr size_t kDefaultNumThreads = ~size_t{0}; diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ed8c36b..e095657 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -717,8 +717,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, } template -const WeightsT& GetWeights(const ByteStorageT& weights_u8) { - return *reinterpret_cast*>(weights_u8.get()); +const CompressedWeights& GetWeights(const ByteStorageT& weights_u8) { + return *reinterpret_cast*>(weights_u8.get()); } template @@ -735,7 +735,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info, LayersOutputT* layers_output) { - const WeightsT& weights = GetWeights(weights_u8); + const CompressedWeights& weights = GetWeights(weights_u8); auto& prefill_activations = GetActivations(prefill_u8); auto& activations = GetActivations(decode_u8); @@ -878,7 +878,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, tokenizer_(tokenizer_path), model_type_(model_type), weight_type_(weight_type) { - weights_u8_ = LoadWeights(weights, model_type, weight_type, pool); + weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool); prefill_u8_ = CallForModelAndWeight(model_type, weight_type); decode_u8_ = CallForModelAndWeight(model_type, weight_type); } @@ -889,8 +889,9 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type, 
tokenizer_(std::move(tokenizer)), model_type_(model_type), weight_type_(weight_type) { - weights_u8_ = - CallForModelAndWeight(model_type, weight_type, pool); + HWY_ASSERT(weight_type == Type::kF32); + weights_u8_ = CallForModel( + model_type, pool); prefill_u8_ = CallForModelAndWeight(model_type, weight_type); decode_u8_ = CallForModelAndWeight(model_type, weight_type); } diff --git a/gemma/weights.cc b/gemma/weights.cc index eb03043..a138474 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -29,157 +29,6 @@ namespace gcpp { -// Setting this to true disables fread() calls that read the model file. -constexpr bool kDryRunFread = false; - -namespace { -float ScaleWeights(float* data, size_t len) { - float maxabs = 0.0; - for (size_t i = 0; i < len; ++i) { - maxabs = std::max(maxabs, std::abs(data[i])); - } - const float kMaxRange = 1.875f; - if (maxabs <= kMaxRange) { - return 1.0f; - } - const float scale = maxabs / kMaxRange; - const float inv_scale = 1.0f / scale; - for (size_t i = 0; i < len; ++i) { - data[i] *= inv_scale; - } - return scale; -} - -#define READ_WEIGHTS(name) \ - do { \ - do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \ - } while (0) - -#define SCALE_WEIGHTS(name) \ - do { \ - if (ok && !kDryRunFread && scale_for_compression) { \ - weights->scales[scale_pos++] = \ - ScaleWeights(layer_view->name.data(), layer_view->name.size()); \ - } \ - } while (0) - -template -struct LoadRawWeightsT { - ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool, - bool scale_for_compression) const { - PROFILER_ZONE("Startup.LoadWeights"); - if (!checkpoint.Exists()) { - HWY_ABORT("The model weights file '%s' does not exist.", - checkpoint.path.c_str()); - } - - ByteStorageT weights_u8 = AllocateWeightsF()(pool); - auto* weights = reinterpret_cast*>(weights_u8.get()); - - size_t scale_pos = 0; - FILE* fptr; - if constexpr (kDryRunFread) { - fprintf(stderr, "Dry-Run, not reading model-file.\n"); - } else { - fptr = fopen(checkpoint.path.c_str(), "rb"); - if (fptr == nullptr) { - HWY_ABORT("Failed to open model file %s - does it exist?", - checkpoint.path.c_str()); - } - } - bool ok = true; - uint64_t total_size = 0; - auto do_fread = [&](void* var, int layer, const char* name, size_t size) { - if (layer == -1) { - fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name); - } else { - fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer, - size, name); - } - if constexpr (!kDryRunFread) { - ok &= 1 == fread(var, size, 1, fptr); - total_size += size; - } - }; - do_fread(&(weights->embedder_input_embedding), -1, - "embedder_input_embedding", - sizeof(weights->embedder_input_embedding)); - do_fread(&(weights->final_norm_scale), -1, "final_norm_scale", - sizeof(weights->final_norm_scale)); - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - auto type = TConfig::kLayerConfig[layer]; - LayerF* layer_view = weights->GetLayer(layer); - - // Make sure we don't have uninitialized memory. 
- hwy::ZeroBytes(layer_view, sizeof(*layer_view)); - if (type == LayerAttentionType::kGemma) { - READ_WEIGHTS(attn_vec_einsum_w); - READ_WEIGHTS(qkv_einsum_w); - SCALE_WEIGHTS(attn_vec_einsum_w); - SCALE_WEIGHTS(qkv_einsum_w); - } else { - READ_WEIGHTS(griffin.linear_x_w); - READ_WEIGHTS(griffin.linear_x_biases); - READ_WEIGHTS(griffin.linear_y_w); - READ_WEIGHTS(griffin.linear_y_biases); - READ_WEIGHTS(griffin.linear_out_w); - READ_WEIGHTS(griffin.linear_out_biases); - READ_WEIGHTS(griffin.conv_w); - READ_WEIGHTS(griffin.conv_biases); - READ_WEIGHTS(griffin.gate_w); - READ_WEIGHTS(griffin.gate_biases); - READ_WEIGHTS(griffin.a); - SCALE_WEIGHTS(griffin.linear_x_w); - SCALE_WEIGHTS(griffin.linear_y_w); - SCALE_WEIGHTS(griffin.linear_out_w); - SCALE_WEIGHTS(griffin.gate_w); - } - READ_WEIGHTS(gating_einsum_w); - READ_WEIGHTS(linear_w); - SCALE_WEIGHTS(gating_einsum_w); - SCALE_WEIGHTS(linear_w); - READ_WEIGHTS(pre_attention_norm_scale); - READ_WEIGHTS(pre_ffw_norm_scale); - if (TConfig::kPostNormScale) { - READ_WEIGHTS(post_attention_norm_scale); - READ_WEIGHTS(post_ffw_norm_scale); - } - if (TConfig::kFFBiases) { - READ_WEIGHTS(ffw_gating_biases); - READ_WEIGHTS(ffw_output_biases); - } - if (TConfig::kSoftmaxAttnOutputBiases && - type == LayerAttentionType::kGemma) { - READ_WEIGHTS(attention_output_biases); - } - } - if (!ok) { - HWY_ABORT( - "Failed to read from %s - might be a directory, or too small? " - "expected size: %d kB", - checkpoint.path.c_str(), static_cast(total_size >> 10)); - } - if (!kDryRunFread) { - HWY_ASSERT(0 == fclose(fptr)); - if (scale_for_compression) { - HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); - } - } - return weights_u8; - } -}; - -#undef READ_WEIGHTS -#undef SCALE_WEIGHTS -} // namespace - -ByteStorageT LoadRawWeights(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool, - bool scale_for_compression) { - return CallForModelAndWeight( - model_type, weight_type, weights, pool, scale_for_compression); -} - namespace { template struct LoadCompressedWeightsT { @@ -234,16 +83,6 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, weights, pool); } -ByteStorageT LoadWeights(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool) { - if constexpr (kWeightsAreCompressed) { - return LoadCompressedWeights(weights, model_type, weight_type, pool); - } else { - return LoadRawWeights(weights, model_type, weight_type, pool, - /*scale_for_compression=*/false); - } -} - namespace { void LogVec(const char* name, const float* data, size_t len) { hwy::Stats stats; @@ -257,7 +96,7 @@ void LogVec(const char* name, const float* data, size_t len) { class WeightLogger { public: template - void operator()(const char* name, const std::array& tensor) { + void operator()(const char* name, const CompressedArray& tensor) { LogVec(name, tensor.data(), N); total_weights += N; } @@ -268,9 +107,9 @@ template struct LogWeightStatsT { void operator()(const ByteStorageT& weights_u8) const { const auto& weights = - *reinterpret_cast*>(weights_u8.get()); + *reinterpret_cast*>(weights_u8.get()); WeightLogger logger; - ForEachTensor1(logger, weights); + ForEachTensor1(logger, weights); printf("%-20s %12zu\n", "Total", logger.total_weights); } }; @@ -278,7 +117,8 @@ struct LogWeightStatsT { void LogWeightStats(gcpp::Model model_type, Type weight_type, const ByteStorageT& weights) { - CallForModelAndWeight(model_type, weight_type, weights); + HWY_ASSERT(weight_type == Type::kF32); + CallForModel(model_type, weights); 
} } // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index c0bda43..d9f5c0f 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -25,9 +25,6 @@ namespace gcpp { -// Setting this to false will load and use uncompressed weights. -constexpr bool kWeightsAreCompressed = true; - // ---------------------------------------------------------------------------- // Uncompressed @@ -213,11 +210,16 @@ struct CompressedLayerPointers { template struct CompressedWeights { // No ctor/dtor, allocated via AllocateAligned. + using Weight = typename TConfig::Weight; - CompressedArray + using WeightF32OrInputT = + hwy::If(), float, EmbedderInputT>; + CompressedArray embedder_input_embedding; - CompressedArray final_norm_scale; + using WeightF32OrBF16 = + hwy::If(), float, hwy::bfloat16_t>; + CompressedArray final_norm_scale; // Must be last so that the other arrays remain aligned. CompressedLayerPointers c_layer_ptrs; @@ -233,10 +235,6 @@ struct CompressedWeights { // ---------------------------------------------------------------------------- // Interface -template -using WeightsT = hwy::If, - WeightsF>; - // TODO: can we use TConfig::Weight instead of T? template struct AllocateWeights { @@ -256,6 +254,17 @@ struct AllocateWeightsF { } }; +template +struct AllocateCompressedWeights { + ByteStorageT operator()(hwy::ThreadPool& pool) const { + using TWeights = CompressedWeights; + ByteStorageT weights_u8 = AllocateSizeof(); + TWeights* weights = reinterpret_cast(weights_u8.get()); + new (&weights->c_layer_ptrs) CompressedLayerPointers(pool); + return weights_u8; + } +}; + template struct ZeroInitWeights { void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { @@ -277,6 +286,20 @@ struct ZeroInitWeightsF { } }; +template +struct ZeroInitCompressedWeights { + void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { + CompressedWeights& w = + *reinterpret_cast*>(weights.get()); + hwy::ZeroBytes(&w.embedder_input_embedding, + sizeof(w.embedder_input_embedding)); + hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); + for (int i = 0; i < TConfig::kLayers; ++i) { + hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i))); + } + } +}; + template struct CopyWeights { void operator()(Weights& dst, @@ -295,12 +318,9 @@ void operator()(Weights& dst, template struct DeleteLayersPtrs { void operator()(ByteStorageT& weights_u8) const { - auto* weights = reinterpret_cast*>(weights_u8.get()); - if constexpr (kWeightsAreCompressed) { - weights->c_layer_ptrs.~CompressedLayerPointers(); - } else { - weights->layer_ptrs.~LayerPointers(); - } + auto* weights = + reinterpret_cast*>(weights_u8.get()); + weights->c_layer_ptrs.~CompressedLayerPointers(); } }; @@ -330,14 +350,8 @@ class WeightsWrapper { Weights* weights_; }; -// For use by compress_weights.cc. -ByteStorageT LoadRawWeights(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool, - bool scale_for_compression); - -// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed. 
-ByteStorageT LoadWeights(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool); +ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, + Type weight_type, hwy::ThreadPool& pool); void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights); @@ -467,62 +481,62 @@ void ForEachTensor(const WeightsF* weights, GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ } -template -void ForEachTensor1(Func& func, const Weights& weights1) { +template +void ForEachTensor1(Func& func, const CompressedWeights& weights1) { GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - const LayerF& layer1 = *weights1.GetLayer(idx); + const CompressedLayer& layer1 = *weights1.GetLayer(idx); GEMMA_CALL_ALL_LAYER_FUNC(1) } } -template -void ForEachTensor1(Func& func, Weights& weights1) { +template +void ForEachTensor1(Func& func, CompressedWeights& weights1) { GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - LayerF& layer1 = *weights1.GetLayer(idx); + CompressedLayer& layer1 = *weights1.GetLayer(idx); GEMMA_CALL_ALL_LAYER_FUNC(1) } } -template -void ForEachTensor2(Func& func, const Weights& weights1, - Weights& weights2) { +template +void ForEachTensor2(Func& func, const CompressedWeights& weights1, + CompressedWeights& weights2) { GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - const LayerF& layer1 = *weights1.GetLayer(idx); - LayerF& layer2 = *weights2.GetLayer(idx); + const CompressedLayer& layer1 = *weights1.GetLayer(idx); + CompressedLayer& layer2 = *weights2.GetLayer(idx); GEMMA_CALL_ALL_LAYER_FUNC(2) } } -template -void ForEachTensor4(Func& func, const Weights& weights1, - Weights& weights2, - Weights& weights3, - Weights& weights4) { +template +void ForEachTensor4(Func& func, const CompressedWeights& weights1, + CompressedWeights& weights2, + CompressedWeights& weights3, + CompressedWeights& weights4) { GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale); char name_buf[16]; for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - const LayerF& layer1 = *weights1.GetLayer(idx); - LayerF& layer2 = *weights2.GetLayer(idx); - LayerF& layer3 = *weights3.GetLayer(idx); - LayerF& layer4 = *weights4.GetLayer(idx); + const CompressedLayer& layer1 = *weights1.GetLayer(idx); + CompressedLayer& layer2 = *weights2.GetLayer(idx); + CompressedLayer& layer3 = *weights3.GetLayer(idx); + CompressedLayer& layer4 = *weights4.GetLayer(idx); GEMMA_CALL_ALL_LAYER_FUNC(4) } } From 6ca4a8e345e3babd981e6e172214818fa150f25a Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Mon, 10 Jun 2024 15:27:22 +0000 Subject: [PATCH 2/2] Address review comments --- backprop/optimizer.cc | 
21 +++++++++++++--------
 compression/compress.h |  3 ---
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc
index 93d3164..f004446 100644
--- a/backprop/optimizer.cc
+++ b/backprop/optimizer.cc
@@ -33,8 +33,9 @@ class WeightInitializer {
   template <size_t N>
   void operator()(const char* name, CompressedArray<float, N>& tensor) {
+    float* data = tensor.data();
     for (size_t i = 0; i < N; ++i) {
-      tensor[i] = dist_(gen_);
+      data[i] = dist_(gen_);
     }
     tensor.set_scale(1.0f);
   }
  private:
   std::normal_distribution<float> dist_;
@@ -70,14 +71,18 @@ class AdamUpdater {
                   CompressedArray<float, kCapacity>& weights,
                   CompressedArray<float, kCapacity>& grad_m,
                   CompressedArray<float, kCapacity>& grad_v) {
+    const float* HWY_RESTRICT g = grad.data();
+    float* HWY_RESTRICT w = weights.data();
+    float* HWY_RESTRICT m = grad_m.data();
+    float* HWY_RESTRICT v = grad_v.data();
     for (size_t i = 0; i < kCapacity; ++i) {
-      grad_m[i] *= beta1_;
-      grad_m[i] += cbeta1_ * grad[i];
-      grad_v[i] *= beta2_;
-      grad_v[i] += cbeta2_ * grad[i] * grad[i];
-      const float mhat = grad_m[i] * norm1_;
-      const float vhat = grad_v[i] * norm2_;
-      weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
+      m[i] *= beta1_;
+      m[i] += cbeta1_ * g[i];
+      v[i] *= beta2_;
+      v[i] += cbeta2_ * g[i] * g[i];
+      const float mhat = m[i] * norm1_;
+      const float vhat = v[i] * norm2_;
+      w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
     }
   }

diff --git a/compression/compress.h b/compression/compress.h
index 344cabc..edb7fdb 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -79,9 +79,6 @@ class CompressedArray {
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }

-  MatT& operator[](size_t pos) { return data_[pos]; }
-  const MatT& operator[](size_t pos) const { return data_[pos]; }
-
   float scale() const { return scale_[0]; }
   void set_scale(float scale) { scale_[0] = scale; }
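
Note on the optimizer change above: the review round drops CompressedArray::operator[] and instead takes raw pointers from data(), so the Adam loop in AdamUpdater runs over plain float buffers. Below is a minimal standalone sketch, not part of the patch, of the same bias-corrected Adam step; AdamStep and the toy main() are illustrative names for this sketch only, not gemma.cpp APIs.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// One Adam step over a flat parameter buffer. `t` is the 1-based step index;
// norm1/norm2 are the bias-correction factors that AdamUpdater precomputes.
void AdamStep(const float* g, float* w, float* m, float* v, size_t n,
              float alpha, float beta1, float beta2, float epsilon, size_t t) {
  const float cbeta1 = 1.0f - beta1;
  const float cbeta2 = 1.0f - beta2;
  const float norm1 = 1.0f / (1.0f - std::pow(beta1, static_cast<float>(t)));
  const float norm2 = 1.0f / (1.0f - std::pow(beta2, static_cast<float>(t)));
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + cbeta1 * g[i];            // first moment
    v[i] = beta2 * v[i] + cbeta2 * g[i] * g[i];     // second moment
    const float mhat = m[i] * norm1;                // bias-corrected
    const float vhat = v[i] * norm2;                // bias-corrected
    w[i] -= alpha * mhat / (std::sqrt(vhat) + epsilon);
  }
}

int main() {
  // Toy example: minimize f(w) = 0.5 * w^2, whose gradient is w itself.
  std::vector<float> w = {1.0f, -2.0f, 3.0f};
  std::vector<float> m(w.size(), 0.0f), v(w.size(), 0.0f), g(w.size());
  for (size_t t = 1; t <= 100; ++t) {
    for (size_t i = 0; i < w.size(); ++i) g[i] = w[i];  // grad of 0.5*w^2
    AdamStep(g.data(), w.data(), m.data(), v.data(), w.size(),
             /*alpha=*/0.1f, /*beta1=*/0.9f, /*beta2=*/0.999f,
             /*epsilon=*/1e-8f, t);
  }
  std::printf("w after 100 steps: %f %f %f\n", w[0], w[1], w[2]);
  return 0;
}

The per-tensor loop in AdamUpdater is this same update, with alpha_, beta1_, cbeta1_, norm1_ and friends computed once per call in its constructor rather than per tensor.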