From 0d68555f875d3b34d29e8c5c7290e10a8fae3609 Mon Sep 17 00:00:00 2001
From: Ray Smith
Date: Thu, 17 Oct 2024 05:03:35 -0700
Subject: [PATCH] Eliminated TConfig.

Changed CompressedLayer and CompressedWeights to be constructed from an
instance of LayerConfig and WeightsConfig respectively. Added CompressedModel
to remove ByteStorageT and eliminate most of the type casting, and to allow
the default destructor to be used and work properly. Adjusted WeightsWrapper,
ForwardLayer, etc. to match. The only remaining template argument is the
weight type, which allows all of the per-model instantiations to be deleted,
apart from one per weight type. It also makes it possible (not yet done) to
store the config in the blob file instead of having to specify it separately.
Reduces the size of the gemma_lib and weights shared libraries by factors of
4.3 and 3.2 respectively.

PiperOrigin-RevId: 686870060
---
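[Note: the essence of the change, as a minimal standalone sketch. Names are
simplified for illustration and are not the actual gemma.cpp declarations:
configuration moves from compile-time template constants to runtime struct
fields, so one compiled function can serve every model.]

#include <cstddef>
#include <vector>

// Before: one template instantiation per model config (and per weight type).
struct Gemma2BConfig {
  static constexpr size_t kModelDim = 2048;
  static constexpr size_t kLayers = 18;
};
template <typename TConfig>
size_t ParamCountOld() {  // Compiled separately for every TConfig.
  return TConfig::kModelDim * TConfig::kLayers;
}

// After: plain runtime data; a single compiled function handles all models.
struct LayerConfig { size_t model_dim; };
struct ModelConfig { std::vector<LayerConfig> layer_configs; };
size_t ParamCountNew(const ModelConfig& config) {
  size_t total = 0;
  for (const LayerConfig& layer : config.layer_configs) {
    total += layer.model_dim;
  }
  return total;
}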
 BUILD.bazel                                 |   58 +-
 CMakeLists.txt                              |   29 +-
 backprop/activations.h                      |   93 +-
 backprop/backward-inl.h                     |  258 ++---
 backprop/backward.cc                        |   54 +-
 backprop/backward.h                         |   14 +-
 backprop/backward_scalar.h                  |  176 ++-
 backprop/backward_scalar_test.cc            |  137 +--
 backprop/backward_test.cc                   |   90 +-
 backprop/forward-inl.h                      |  132 +--
 backprop/forward.cc                         |   46 +-
 backprop/forward.h                          |   11 +-
 backprop/forward_scalar.h                   |  154 ++-
 backprop/optimize_test.cc                   |   58 +-
 backprop/optimizer.cc                       |   85 +-
 backprop/optimizer.h                        |   17 +-
 backprop/test_util.h                        |   66 +-
 compression/blob_store.cc                   |    2 +-
 compression/compress.h                      |   47 +-
 compression/compress_weights.cc             |   59 +-
 compression/shared.h                        |   61 +-
 evals/benchmark.cc                          |    4 +-
 evals/benchmark_helper.cc                   |    6 +-
 evals/cross_entropy.cc                      |    3 +-
 examples/hello_world/run.cc                 |    4 +-
 gemma/activations.h                         |   82 +-
 gemma/common.cc                             |   19 +-
 gemma/common.h                              |  212 +---
 gemma/configs.cc                            |  246 ++++
 gemma/configs.h                             |  419 ++-----
 gemma/configs_test.cc                       |  445 ++++++++
 gemma/gemma-inl.h                           | 1117 ++++++++++---------
 gemma/gemma.cc                              |   92 +-
 gemma/gemma.h                               |   20 +-
 gemma/instantiations/27b_bf16.cc            |   21 -
 gemma/instantiations/27b_f32.cc             |   21 -
 gemma/instantiations/27b_sfp.cc             |   21 -
 gemma/instantiations/2b_bf16.cc             |   21 -
 gemma/instantiations/7b_bf16.cc             |   21 -
 gemma/instantiations/7b_sfp.cc              |   21 -
 gemma/instantiations/9b_bf16.cc             |   21 -
 gemma/instantiations/9b_sfp.cc              |   21 -
 gemma/instantiations/{2b_f32.cc => bf16.cc} |    5 +-
 gemma/instantiations/{7b_f32.cc => f32.cc}  |    5 +-
 gemma/instantiations/gemma2_2b_bf16.cc      |   21 -
 gemma/instantiations/gemma2_2b_f32.cc       |   21 -
 gemma/instantiations/gemma2_2b_sfp.cc       |   21 -
 gemma/instantiations/gr2b_bf16.cc           |   21 -
 gemma/instantiations/gr2b_f32.cc            |   21 -
 gemma/instantiations/gr2b_sfp.cc            |   21 -
 gemma/instantiations/{9b_f32.cc => nuq.cc}  |    5 +-
 gemma/instantiations/paligemma_224_bf16.cc  |   21 -
 gemma/instantiations/paligemma_224_f32.cc   |   21 -
 gemma/instantiations/paligemma_224_sfp.cc   |   21 -
 gemma/instantiations/{2b_sfp.cc => sfp.cc}  |    5 +-
 gemma/instantiations/tiny_bf16.cc           |   21 -
 gemma/instantiations/tiny_f32.cc            |   21 -
 gemma/instantiations/tiny_sfp.cc            |   21 -
 gemma/kv_cache.cc                           |   87 +-
 gemma/kv_cache.h                            |    3 +-
 gemma/run.cc                                |    2 +-
 gemma/weights.cc                            |  224 ++--
 gemma/weights.h                             |  459 ++++----
 ops/gemma_matvec_test.cc                    |   14 +-
 ops/matvec-inl.h                            |  118 +-
 ops/ops-inl.h                               |   21 +-
 ops/ops_test.cc                             |   26 +-
 util/app.h                                  |    2 +
 68 files changed, 2810 insertions(+), 2902 deletions(-)
 create mode 100644 gemma/configs.cc
 create mode 100644 gemma/configs_test.cc
 delete mode 100644 gemma/instantiations/27b_bf16.cc
 delete mode 100644 gemma/instantiations/27b_f32.cc
 delete mode 100644 gemma/instantiations/27b_sfp.cc
 delete mode 100644 gemma/instantiations/2b_bf16.cc
 delete mode 100644 gemma/instantiations/7b_bf16.cc
 delete mode 100644 gemma/instantiations/7b_sfp.cc
 delete mode 100644 gemma/instantiations/9b_bf16.cc
 delete mode 100644 gemma/instantiations/9b_sfp.cc
 rename gemma/instantiations/{2b_f32.cc => bf16.cc} (87%)
 rename gemma/instantiations/{7b_f32.cc => f32.cc} (87%)
 delete mode 100644 gemma/instantiations/gemma2_2b_bf16.cc
 delete mode 100644 gemma/instantiations/gemma2_2b_f32.cc
 delete mode 100644 gemma/instantiations/gemma2_2b_sfp.cc
 delete mode 100644 gemma/instantiations/gr2b_bf16.cc
 delete mode 100644 gemma/instantiations/gr2b_f32.cc
 delete mode 100644 gemma/instantiations/gr2b_sfp.cc
 rename gemma/instantiations/{9b_f32.cc => nuq.cc} (87%)
 delete mode 100644 gemma/instantiations/paligemma_224_bf16.cc
 delete mode 100644 gemma/instantiations/paligemma_224_f32.cc
 delete mode 100644 gemma/instantiations/paligemma_224_sfp.cc
 rename gemma/instantiations/{2b_sfp.cc => sfp.cc} (87%)
 delete mode 100644 gemma/instantiations/tiny_bf16.cc
 delete mode 100644 gemma/instantiations/tiny_f32.cc
 delete mode 100644 gemma/instantiations/tiny_sfp.cc

diff --git a/BUILD.bazel b/BUILD.bazel
index 1e0cc73..c480f23 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -104,8 +104,6 @@ cc_test(
     tags = ["hwy_ops_test"],
     deps = [
         ":allocator",
-        ":common",
-        ":gemma_lib",
        ":ops",
         ":test_util",
         ":threading",
@@ -183,7 +181,10 @@ cc_test(
 
 cc_library(
     name = "common",
-    srcs = ["gemma/common.cc"],
+    srcs = [
+        "gemma/common.cc",
+        "gemma/configs.cc",
+    ],
     hdrs = [
         "gemma/common.h",
         "gemma/configs.h",
@@ -195,12 +196,20 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "configs_test",
+    srcs = ["gemma/configs_test.cc"],
+    deps = [
+        ":common",
+        "@googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "weights",
     srcs = ["gemma/weights.cc"],
     hdrs = ["gemma/weights.h"],
     deps = [
-        ":allocator",
         ":common",
         "//compression:compress",
         "//compression:io",
@@ -219,7 +228,6 @@ cc_library(
         ":common",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:nanobenchmark",  # timer
         "@highway//:profiler",
         "@com_google_sentencepiece//:sentencepiece_processor",
     ],
@@ -239,30 +247,10 @@ cc_library(
     name = "gemma_lib",
     srcs = [
         "gemma/gemma.cc",
-        "gemma/instantiations/27b_bf16.cc",
-        "gemma/instantiations/27b_f32.cc",
-        "gemma/instantiations/27b_sfp.cc",
-        "gemma/instantiations/2b_bf16.cc",
-        "gemma/instantiations/2b_f32.cc",
-        "gemma/instantiations/2b_sfp.cc",
-        "gemma/instantiations/7b_bf16.cc",
-        "gemma/instantiations/7b_f32.cc",
-        "gemma/instantiations/7b_sfp.cc",
-        "gemma/instantiations/9b_bf16.cc",
-        "gemma/instantiations/9b_f32.cc",
-        "gemma/instantiations/9b_sfp.cc",
-        "gemma/instantiations/tiny_bf16.cc",
-        "gemma/instantiations/tiny_f32.cc",
-        "gemma/instantiations/tiny_sfp.cc",
-        "gemma/instantiations/gr2b_bf16.cc",
-        "gemma/instantiations/gr2b_f32.cc",
-        "gemma/instantiations/gr2b_sfp.cc",
-        "gemma/instantiations/gemma2_2b_bf16.cc",
-        "gemma/instantiations/gemma2_2b_f32.cc",
-        "gemma/instantiations/gemma2_2b_sfp.cc",
-        "gemma/instantiations/paligemma_224_bf16.cc",
-        "gemma/instantiations/paligemma_224_f32.cc",
-        "gemma/instantiations/paligemma_224_sfp.cc",
+        "gemma/instantiations/bf16.cc",
+        "gemma/instantiations/f32.cc",
+        "gemma/instantiations/nuq.cc",
+        "gemma/instantiations/sfp.cc",
     ],
     hdrs = [
         "gemma/activations.h",
@@ -327,8 +315,6 @@ cc_library(
         ":threading",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:thread_pool",
-        "@highway//:topology",
     ],
 )
 
@@ -367,7 +353,6 @@ cc_test(
         ":benchmark_helper",
         ":common",
         ":gemma_lib",
-        ":tokenizer",
         "@googletest//:gtest_main",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
@@ -396,7 +381,6 @@ cc_binary(
     name = "single_benchmark",
     srcs = ["evals/benchmark.cc"],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":common",
@@ -405,7 +389,6 @@ cc_binary(
         "//compression:io",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
-        "@highway//:thread_pool",
         "@nlohmann_json//:json",
     ],
 )
@@ -429,13 +412,11 @@ cc_binary(
         "evals/debug_prompt.cc",
     ],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":gemma_lib",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:thread_pool",
         "@nlohmann_json//:json",
     ],
 )
@@ -444,7 +425,6 @@ cc_binary(
     name = "gemma_mmlu",
     srcs = ["evals/run_mmlu.cc"],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":gemma_lib",
@@ -488,7 +468,6 @@ cc_library(
     deps = [
         ":allocator",
         ":common",
-        ":gemma_lib",
         ":ops",
         ":prompt",
         ":weights",
@@ -508,7 +487,6 @@ cc_library(
         "backprop/forward_scalar.h",
     ],
     deps = [
-        ":allocator",
         ":common",
         ":prompt",
         ":weights",
@@ -525,7 +503,6 @@ cc_test(
         "backprop/test_util.h",
     ],
     deps = [
-        ":allocator",
         ":backprop_scalar",
         ":common",
         ":prompt",
@@ -599,6 +576,7 @@ cc_test(
         ":threading",
         ":weights",
         "@googletest//:gtest_main",
+        "//compression:sfp",
         "@highway//:thread_pool",
     ],
 )
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 51ab2e4..bade5de 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,34 +68,15 @@ set(SOURCES
   gemma/activations.h
   gemma/common.cc
   gemma/common.h
+  gemma/configs.cc
   gemma/configs.h
   gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
-  gemma/instantiations/27b_bf16.cc
-  gemma/instantiations/27b_f32.cc
-  gemma/instantiations/27b_sfp.cc
-  gemma/instantiations/2b_bf16.cc
-  gemma/instantiations/2b_f32.cc
-  gemma/instantiations/2b_sfp.cc
-  gemma/instantiations/7b_bf16.cc
-  gemma/instantiations/7b_f32.cc
-  gemma/instantiations/7b_sfp.cc
-  gemma/instantiations/9b_bf16.cc
-  gemma/instantiations/9b_f32.cc
-  gemma/instantiations/9b_sfp.cc
-  gemma/instantiations/gr2b_bf16.cc
-  gemma/instantiations/gr2b_f32.cc
-  gemma/instantiations/gr2b_sfp.cc
-  gemma/instantiations/tiny_bf16.cc
-  gemma/instantiations/tiny_f32.cc
-  gemma/instantiations/tiny_sfp.cc
-  gemma/instantiations/gemma2_2b_bf16.cc
-  gemma/instantiations/gemma2_2b_f32.cc
-  gemma/instantiations/gemma2_2b_sfp.cc
-  gemma/instantiations/paligemma_224_bf16.cc
-  gemma/instantiations/paligemma_224_f32.cc
-  gemma/instantiations/paligemma_224_sfp.cc
+  gemma/instantiations/bf16.cc
+  gemma/instantiations/f32.cc
+  gemma/instantiations/nuq.cc
+  gemma/instantiations/sfp.cc
   gemma/kv_cache.cc
   gemma/kv_cache.h
   gemma/tokenizer.cc
diff --git a/backprop/activations.h b/backprop/activations.h
index 4f2e821..c616759 100644
--- a/backprop/activations.h
+++ b/backprop/activations.h
@@ -18,32 +18,27 @@
 
 #include <stddef.h>
 
-#include <array>
+#include <vector>
 
 #include "compression/compress.h"  // MatStorageT
-#include "util/allocator.h"  // ByteStorageT
+#include "gemma/configs.h"  // ModelConfig
 
 namespace gcpp {
 
-template <typename T, typename TConfig>
+template <typename T>
 struct ForwardLayer {
-  ForwardLayer()
-      : input("input", kSeqLen, kModelDim),
-        pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim),
-        qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim),
-        att("att", kSeqLen * kHeads, kSeqLen),
-        att_out("att_out", kSeqLen * kHeads, kQKVDim),
-        att_post1("att_post1", kSeqLen, kModelDim),
-        attention_out("attention_out", kSeqLen, kModelDim),
-        bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim),
-        ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2),
-        ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {}
-
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  ForwardLayer(const LayerConfig& config, size_t seq_len)
+      : input("input", seq_len, config.model_dim),
+        pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
+        qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
+        att("att", seq_len * config.heads, seq_len),
+        att_out("att_out", seq_len * config.heads, config.qkv_dim),
+        att_post1("att_post1", seq_len, config.model_dim),
+        attention_out("attention_out", seq_len, config.model_dim),
+        bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
+        ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
+        ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
+        layer_config(config) {}
 
   MatStorageT<T> input;
   MatStorageT<T> pre_att_rms_out;
@@ -55,56 +50,30 @@ struct ForwardLayer {
   MatStorageT<T> bf_pre_ffw_rms_out;
   MatStorageT<T> ffw_hidden;
   MatStorageT<T> ffw_hidden_gated;
+  const LayerConfig& layer_config;
 };
 
-template <typename T, typename TConfig>
+template <typename T>
 struct ForwardPass {
-  ForwardPass()
-      : final_layer_output("final_layer_output", kSeqLen, kModelDim),
-        final_norm_output("final_norm_output", kSeqLen, kModelDim),
-        logits("logits", kSeqLen, kVocabSize),
-        probs("probs", kSeqLen, kVocabSize) {
-  }  // prevents placement-new calling memset
+  ForwardPass(const ModelConfig& config)
+      : final_layer_output("final_layer_output", config.seq_len,
+                           config.model_dim),
+        final_norm_output("final_norm_output", config.seq_len,
+                          config.model_dim),
+        logits("logits", config.seq_len, config.vocab_size),
+        probs("probs", config.seq_len, config.vocab_size),
+        weights_config(config) {
+    for (const auto& layer_config : config.layer_configs) {
+      layers.emplace_back(layer_config, config.seq_len);
+    }
+  }
 
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kLayers = TConfig::kLayers;
-
-  std::array<ForwardLayer<T, TConfig>, kLayers> layers;
+  std::vector<ForwardLayer<T>> layers;
   MatStorageT<T> final_layer_output;
   MatStorageT<T> final_norm_output;
   MatStorageT<T> logits;
   MatStorageT<T> probs;
-};
-
-template <typename TConfig>
-struct AllocateForwardPass {
-  ByteStorageT operator()() const {
-    ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<float, TConfig>>();
-    auto* c_weights =
-        reinterpret_cast<ForwardPass<float, TConfig>*>(c_weights_u8.get());
-    new (c_weights) ForwardPass<float, TConfig>();
-    return c_weights_u8;
-  }
-};
-
-// Owns activations and undoes the type erasure of AllocateAligned.
-template <typename T, typename TConfig>
-class ActivationsWrapper {
-  using WrappedT = ForwardPass<T, TConfig>;
-
- public:
-  ActivationsWrapper()
-      : data_(AllocateSizeof<WrappedT>()),
-        activations_(*(new(data_.get()) WrappedT())) {}
-
-  const WrappedT& get() const { return activations_; }
-  WrappedT& get() { return activations_; }
-
- private:
-  ByteStorageT data_;
-  WrappedT& activations_;
+  const ModelConfig& weights_config;
 };
 
 }  // namespace gcpp
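[Note: a hedged usage sketch of the new constructors. The config values are
invented for illustration; compare the TestConfig() helpers later in this
patch.]

#include "backprop/activations.h"  // ForwardPass
#include "gemma/configs.h"         // ModelConfig, LayerConfig

// Build a toy ModelConfig at runtime and size a ForwardPass from it.
// ForwardPass stores a const ModelConfig&, so config must outlive it.
void ToyForwardPassDemo() {
  gcpp::ModelConfig config;
  config.model_dim = 32;
  config.vocab_size = 12;
  config.seq_len = 18;
  gcpp::LayerConfig layer = {.model_dim = config.model_dim,
                             .ff_hidden_dim = 48,
                             .heads = 3,
                             .kv_heads = 1,
                             .qkv_dim = 12};
  config.layer_configs = {2, layer};          // Two identical layers.
  gcpp::ForwardPass<float> forward(config);   // Buffers sized at runtime.
  (void)forward;
}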
diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index f765a5a..2a0f330 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -28,6 +28,7 @@
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
 #include "gemma/common.h"
+#include "gemma/weights.h"
 #include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -53,45 +54,41 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-template <size_t kCols, size_t kRows>
-void MatMulVJP(const float* HWY_RESTRICT weights,  // kRows * kCols,
-               const float* HWY_RESTRICT x,        // num_tokens * kCols
-               const float* HWY_RESTRICT v,        // num_tokens * kRows
-               size_t num_tokens,
-               float* HWY_RESTRICT grad_w,  // kRows * kCols,
-               float* HWY_RESTRICT grad_x,  // num_tokens * kCols
-               hwy::ThreadPool& pool) {
-  hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
+HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights,  // kRows * kCols,
+                          const float* HWY_RESTRICT x,   // num_tokens * kCols
+                          const float* HWY_RESTRICT v,   // num_tokens * kRows
+                          size_t cols, size_t rows, size_t num_tokens,
+                          float* HWY_RESTRICT grad_w,  // kRows * kCols,
+                          float* HWY_RESTRICT grad_x,  // num_tokens * kCols
+                          hwy::ThreadPool& pool) {
+  hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    const size_t voffs = pos * kRows;
-    const size_t xoffs = pos * kCols;
-    for (size_t j = 0; j < kRows; ++j) {
-      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
-      MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
-                       kCols);
+    const size_t voffs = pos * rows;
+    const size_t xoffs = pos * cols;
+    for (size_t j = 0; j < rows; ++j) {
+      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
+      MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
     }
   }
 }
 
-template <size_t kHeads, size_t kCols, size_t kRows>
-void MultiHeadMatMulVJP(
-    const float* HWY_RESTRICT weights,  // kHeads * kRows * kCols
-    const float* HWY_RESTRICT x,        // num_tokens * kHeads * kCols
+HWY_INLINE void MultiHeadMatMulVJP(
+    const float* HWY_RESTRICT weights,  // heads * kRows * kCols
+    const float* HWY_RESTRICT x,        // num_tokens * heads * kCols
     const float* HWY_RESTRICT v,        // num_tokens * kRows
-    size_t num_tokens,
-    float* HWY_RESTRICT grad_w,  // kHeads * kRows * kCols
-    float* HWY_RESTRICT grad_x,  // num_tokens * kHeads * kCols
+    size_t heads, size_t cols, size_t rows, size_t num_tokens,
+    float* HWY_RESTRICT grad_w,  // heads * kRows * kCols
+    float* HWY_RESTRICT grad_x,  // num_tokens * heads * kCols
     hwy::ThreadPool& pool) {
-  hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
+  hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    for (size_t j = 0; j < kRows; ++j) {
-      for (size_t h = 0; h < kHeads; ++h) {
-        MulByConstAndAdd(v[pos * kRows + j],
-                         &x[pos * kHeads * kCols + h * kCols],
-                         &grad_w[h * kRows * kCols + j * kCols], kCols);
-        MulByConstAndAdd(v[pos * kRows + j],
-                         &weights[h * kRows * kCols + j * kCols],
-                         &grad_x[pos * kHeads * kCols + h * kCols], kCols);
+    for (size_t j = 0; j < rows; ++j) {
+      for (size_t h = 0; h < heads; ++h) {
+        MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
+                         &grad_w[h * rows * cols + j * cols], cols);
+        MulByConstAndAdd(v[pos * rows + j],
+                         &weights[h * rows * cols + j * cols],
+                         &grad_x[pos * heads * cols + h * cols], cols);
       }
     }
  }
 }
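[Note: the math these functions implement, as a minimal scalar sketch. For
y = W x per token (W is rows x cols) with upstream gradient v, the VJP
accumulates dW += v x^T and writes dx = W^T v. Names below are ours, not the
gemma.cpp API.]

#include <cstddef>

void MatMulVJPScalar(const float* W, const float* x, const float* v,
                     size_t cols, size_t rows, float* grad_w, float* grad_x) {
  for (size_t i = 0; i < cols; ++i) grad_x[i] = 0.0f;
  for (size_t j = 0; j < rows; ++j) {
    for (size_t i = 0; i < cols; ++i) {
      grad_w[j * cols + i] += v[j] * x[i];  // dW = v x^T (accumulated)
      grad_x[i] += v[j] * W[j * cols + i];  // dx = W^T v
    }
  }
}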
@@ -168,39 +165,39 @@ static HWY_NOINLINE void InputEmbeddingVJP(
   }
 }
 
-template <typename TConfig, typename LayerT>
-void LayerVJP(const LayerT& weights,
-              const ForwardLayer<float, TConfig>& forward,
+template <typename T>
+void LayerVJP(const LayerWeightsPtrs<T>& weights,
+              const ForwardLayer<T>& forward,
               const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
-              LayerT& grad, ForwardLayer<float, TConfig>& backward,
+              LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
               const RowVectorBatch<float>& inv_timescale,
               hwy::ThreadPool& pool) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  static const float kQueryScale =
-      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
-  HWY_ASSERT(num_tokens <= kSeqLen);
+  const LayerConfig& config = weights.layer_config;
+  const size_t model_dim = config.model_dim;
+  const size_t qkv_dim = config.qkv_dim;
+  const size_t heads = config.heads;
+  const size_t seq_len = forward.input.Rows();
+  const size_t ff_hidden_dim = config.ff_hidden_dim;
+  const float query_scale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
+  HWY_ASSERT(num_tokens <= seq_len);
 
-  MatMulVJP<kFFHiddenDim, kModelDim>(
-      weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
-      num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
-      pool);
+  MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
+            next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
+            grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    const size_t hidden_offset = pos * kFFHiddenDim * 2;
+    const size_t hidden_offset = pos * ff_hidden_dim * 2;
     const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
-    const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
+    const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
     const float* HWY_RESTRICT b_out_gated =
-        backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
+        backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
     float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
-    float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
+    float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
     DF df;
-    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
+    for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
       const auto y = Load(df, f_out + i);
       const auto x = Load(df, f_out_mul + i);
       const auto v = Load(df, b_out_gated + i);
@@ -209,101 +206,94 @@ void LayerVJP(const LayerT& weights,
     }
   }
 
-  MatMulVJP<kModelDim, kFFHiddenDim * 2>(
-      weights.gating_einsum_w.data(),
-      forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
-      num_tokens, grad.gating_einsum_w.data(),
-      backward.bf_pre_ffw_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_ffw_norm_scale.data(),
-             forward.attention_out.data(),
-             backward.bf_pre_ffw_rms_out.data(),
-             kModelDim, num_tokens,
-             grad.pre_ffw_norm_scale.data(),
-             backward.attention_out.data(), pool);
+  MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
+            backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
+            num_tokens, grad.gating_einsum_w.data(),
+            backward.bf_pre_ffw_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
+             backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
+             grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
+             pool);
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(next_layer_grad + pos * kModelDim,
-            backward.attention_out.data() + pos * kModelDim, kModelDim);
+    AddFrom(next_layer_grad + pos * model_dim,
+            backward.attention_out.data() + pos * model_dim, model_dim);
   }
 
   backward.qkv.ZeroInit();
 
-  MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
-      weights.attn_vec_einsum_w.data(), forward.att_out.data(),
-      backward.attention_out.data(), num_tokens,
-      grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
+  MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
+                     backward.attention_out.data(), heads, qkv_dim, model_dim,
+                     num_tokens, grad.attn_vec_einsum_w.data(),
+                     backward.att_out.data(), pool);
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t aoffset = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
       const float* HWY_RESTRICT b_att_out =
-          backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
+          backward.att_out.data() + (pos * heads + head) * qkv_dim;
       float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
+        const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
         const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
         float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
-        b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
-        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
+        b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
+        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
       }
     }
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t aoffset = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
       float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
       SoftmaxVJP(f_head_att, b_head_att, pos + 1);
     }
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
-      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
+      const size_t aoffs = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
       const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
       float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+        const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
         const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
         float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
-        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
-        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
+        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
+        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
       }
     }
   }
 
   for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
     float* HWY_RESTRICT b_kv =
-        backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
-    Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos);
+        backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
+    Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
       float* HWY_RESTRICT b_q =
-          backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
-      MulByConst(kQueryScale, b_q, kQKVDim);
-      Rope(b_q, kQKVDim, inv_timescale.Const(), -pos);
+          backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
+      MulByConst(query_scale, b_q, qkv_dim);
+      Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
     }
   }
 
-  MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
-      weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
-      backward.qkv.data(), num_tokens,
-      grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_attention_norm_scale.data(),
-             forward.input.data(),
-             backward.pre_att_rms_out.data(),
-             kModelDim, num_tokens,
-             grad.pre_attention_norm_scale.data(),
-             backward.input.data(), pool);
+  MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
+            backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
+            grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
+             backward.pre_att_rms_out.data(), model_dim, num_tokens,
+             grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(backward.attention_out.data() + pos * kModelDim,
-            backward.input.data() + pos * kModelDim, kModelDim);
+    AddFrom(backward.attention_out.data() + pos * model_dim,
+            backward.input.data() + pos * model_dim, model_dim);
   }
 }
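[Note: the backward pass above undoes the forward rotary embedding by calling
Rope with -pos; since rotations are orthogonal, R(-theta) = R(theta)^-1. A
scalar sketch of the rotation; the half-split pairing and base 10000 are
assumptions based on common RoPE implementations, not copied from gemma.cpp.]

#include <cmath>
#include <cstddef>

// Rotates each (x[i], x[i + dim/2]) pair by pos * inv_timescale(i).
// Calling with a negated pos applies the exact inverse rotation.
void RopeScalar(float* x, size_t dim, int pos) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const double inv_timescale = std::pow(10000.0, -2.0 * i / dim);
    const double theta = pos * inv_timescale;
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}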
@@ -342,20 +332,22 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
   }
 }
 
-template <typename TConfig, typename WeightsT>
-void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
-                                  const ForwardPass<float, TConfig>& forward,
-                                  WeightsT& grad,
-                                  ForwardPass<float, TConfig>& backward,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool) {
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kLayers = TConfig::kLayers;
-  const float kEmbScaling = EmbeddingScaling<TConfig>();
-  static_assert(!TConfig::kAbsolutePE);
-  static_assert(TConfig::kPostNorm == PostNormType::None);
-  static_assert(TConfig::kKVHeads == 1);
+template <typename T>
+void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
+                                     const ModelWeightsPtrs<T>& weights,
+                                     const ForwardPass<T>& forward,
+                                     ModelWeightsPtrs<T>& grad,
+                                     ForwardPass<T>& backward,
+                                     RowVectorBatch<float>& inv_timescale,
+                                     hwy::ThreadPool& pool) {
+  const ModelConfig& config = weights.weights_config;
+  const size_t kVocabSize = config.vocab_size;
+  const size_t model_dim = config.model_dim;
+  const size_t kLayers = config.layer_configs.size();
+  const float kEmbScaling = EmbeddingScaling(model_dim);
+  HWY_ASSERT(!config.absolute_pe);
+  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
+  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
 
   HWY_DASSERT(prompt.context_size > 0);
   HWY_DASSERT(prompt.context_size < prompt.tokens.size());
@@ -370,42 +362,38 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
         kVocabSize);
   }
 
-  if constexpr (TConfig::kFinalCap > 0.0f) {
+  if (config.final_cap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+      SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
                  backward.logits.data() + pos * kVocabSize, kVocabSize);
     }
   }
 
-  MatMulVJP<kModelDim, kVocabSize>(
-      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
-      backward.logits.data(), num_tokens,
-      grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
-      pool);
+  MatMulVJP(weights.embedder_input_embedding.data(),
+            forward.final_norm_output.data(), backward.logits.data(), model_dim,
+            kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
+            backward.final_norm_output.data(), pool);
 
-  RMSNormVJP(weights.final_norm_scale.data(),
-             forward.final_layer_output.data(),
-             backward.final_norm_output.data(),
-             kModelDim, num_tokens,
-             grad.final_norm_scale.data(),
-             backward.final_layer_output.data(), pool);
+  RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
+             backward.final_norm_output.data(), model_dim, num_tokens,
+             grad.final_norm_scale.data(), backward.final_layer_output.data(),
+             pool);
 
   for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
-    auto type = TConfig::kLayerConfig[layer];
+    auto layer_config = config.layer_configs[layer];
     // TODO(szabadka) Implement Griffin layer vjp.
-    HWY_ASSERT(type == LayerAttentionType::kGemma);
+    HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
     float* next_layer_grad = layer + 1 < kLayers
                                  ? backward.layers[layer + 1].input.data()
                                  : backward.final_layer_output.data();
-    LayerVJP(*weights.GetLayer(layer), forward.layers[layer],
-             next_layer_grad, num_tokens,
-             *grad.GetLayer(layer), backward.layers[layer],
-             inv_timescale, pool);
+    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
+             num_tokens, *grad.GetLayer(layer), backward.layers[layer],
+             inv_timescale, pool);
   }
 
   InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
                     kEmbScaling, backward.layers[0].input.data(),
-                    grad.embedder_input_embedding.data(), kModelDim);
+                    grad.embedder_input_embedding.data(), model_dim);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
diff --git a/backprop/backward.cc b/backprop/backward.cc
index c186952..868b391 100644
--- a/backprop/backward.cc
+++ b/backprop/backward.cc
@@ -38,44 +38,15 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-template <typename TConfig>
-void CrossEntropyLossBackwardPass(const Prompt& prompt,
-                                  const ByteStorageT& weights_u8,
-                                  const ByteStorageT& forward_u8,
-                                  ByteStorageT& grad_u8,
-                                  ByteStorageT& backward_u8,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool) {
-  using TWeights = CompressedWeights<TConfig>;
-  const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
-  auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
-  using TAct = ForwardPass<float, TConfig>;
-  const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
-  auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
-  CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
-                               CompressedLayer<TConfig>>(
-      prompt, weights, forward, grad, backward, inv_timescale, pool);
-}
-
-void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
-                                   const ByteStorageT& weights,
-                                   const ByteStorageT& forward,
-                                   ByteStorageT& grad, ByteStorageT& backward,
+void CrossEntropyLossBackwardPassT(const Prompt& prompt,
+                                   const ModelWeightsPtrs<float>& weights,
+                                   const ForwardPass<float>& forward,
+                                   ModelWeightsPtrs<float>& grad,
+                                   ForwardPass<float>& backward,
                                    RowVectorBatch<float>& inv_timescale,
                                    hwy::ThreadPool& pool) {
-  // TODO(janwas): use CallFunctorForModel
-  switch (model) {
-    case Model::GEMMA_2B:
-      CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
-          prompt, weights, forward, grad, backward, inv_timescale, pool);
-      break;
-    case Model::GEMMA_TINY:
-      CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
-          prompt, weights, forward, grad, backward, inv_timescale, pool);
-      break;
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
+  CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
+                                  inv_timescale, pool);
 }
 
 }  // namespace HWY_NAMESPACE
@@ -87,14 +58,15 @@ namespace gcpp {
 
 HWY_EXPORT(CrossEntropyLossBackwardPassT);
 
-void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
-                                  const ByteStorageT& weights,
-                                  const ByteStorageT& forward,
-                                  ByteStorageT& grad, ByteStorageT& backward,
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const ModelWeightsPtrs<float>& weights,
+                                  const ForwardPass<float>& forward,
+                                  ModelWeightsPtrs<float>& grad,
+                                  ForwardPass<float>& backward,
                                   RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
   return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
-      model, prompt, weights, forward, grad, backward, inv_timescale, pool);
+      prompt, weights, forward, grad, backward, inv_timescale, pool);
 }
 
 }  // namespace gcpp
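[Note: backward.cc keeps Highway's per-target compilation and runtime
dispatch; only the per-model switch disappears. A minimal sketch of that
pattern under our own names (file name and namespace are hypothetical; see
the Highway documentation for the macros' exact semantics):]

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "example_dispatch.cc"
#include <cstddef>

#include "hwy/foreach_target.h"  // re-includes this file per SIMD target
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace example {
namespace HWY_NAMESPACE {
float SumT(const float* x, size_t n) {  // One copy compiled per target.
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += x[i];
  return sum;
}
}  // namespace HWY_NAMESPACE
}  // namespace example
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace example {
HWY_EXPORT(SumT);  // Table of per-target entry points.
float Sum(const float* x, size_t n) {
  return HWY_DYNAMIC_DISPATCH(SumT)(x, n);  // Picks best target at runtime.
}
}  // namespace example
#endif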
diff --git a/backprop/backward.h b/backprop/backward.h
index 0ac218a..d8e50c7 100644
--- a/backprop/backward.h
+++ b/backprop/backward.h
@@ -16,17 +16,19 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
 
+#include "backprop/activations.h"
 #include "backprop/prompt.h"
-#include "gemma/activations.h"
-#include "gemma/common.h"
+#include "gemma/weights.h"
+#include "util/allocator.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
-                                  const ByteStorageT& weights,
-                                  const ByteStorageT& forward,
-                                  ByteStorageT& grad, ByteStorageT& backward,
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const ModelWeightsPtrs<float>& weights,
+                                  const ForwardPass<float>& forward,
+                                  ModelWeightsPtrs<float>& grad,
+                                  ForwardPass<float>& backward,
                                   RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool);
 
diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h
index a804cd3..b0a37b3 100644
--- a/backprop/backward_scalar.h
+++ b/backprop/backward_scalar.h
@@ -125,65 +125,64 @@ void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
   }
 }
 
-
-template <typename T>
+template <typename T>
 void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
-                        size_t num_tokens, size_t kHeads, size_t kQKVDim,
-                        size_t kSeqLen) {
+                        size_t num_tokens, size_t kHeads, size_t qkv_dim,
+                        size_t seq_len) {
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    const size_t offset = pos * (kHeads + 2) * kQKVDim;
-    memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
+    const size_t offset = pos * (kHeads + 2) * qkv_dim;
+    memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
   }
   for (size_t head = 0; head < kHeads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
-      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
+      const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
       const T* q = qkv + qoffs;
       const T* dout = doutput + aoffs;
       T* dq = dqkv + qoffs;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+        const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
         const T* k = qkv + koffs;
         T* dk = dqkv + koffs;
-        MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
-        MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
+        MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
+        MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
       }
     }
   }
 }
 
-template <typename T>
-void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
-                       size_t kHeads, size_t kSeqLen) {
+template <typename T>
+void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
+                       size_t seq_len) {
   for (size_t head = 0; head < kHeads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
+      size_t offset = pos * kHeads * seq_len + head * seq_len;
       SoftmaxVJPT(y + offset, dy + offset, pos + 1);
-      memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
+      memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
     }
   }
 }
 
-template <typename T>
+template <typename T>
 void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
-                       T* dqkv, T* dattention, size_t num_tokens,
-                       size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
+                       T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
+                       size_t qkv_dim, size_t seq_len) {
   auto v_offset = [&](size_t pos) {
-    return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
+    return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
   };
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
+    memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
   }
   for (size_t head = 0; head < kHeads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
+      const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
       const T* att = &attention[aoffset];
       const T* dout = &doutput[offset];
       T* datt = &dattention[aoffset];
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
-        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
+        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
+        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
       }
     }
   }
@@ -199,77 +198,76 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
   }
 }
 
-template <typename T, typename TConfig>
-void LayerVJP(const CompressedLayer<TConfig>& weights,
-              const ForwardLayer<T, TConfig>& forward, const T* dy,
-              CompressedLayer<TConfig>& grad,
-              ForwardLayer<T, TConfig>& backward, size_t num_tokens) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
+template <typename T>
+void LayerVJP(const LayerWeightsPtrs<T>& weights,
+              const ForwardLayer<T>& forward, const T* dy,
+              LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
+              size_t num_tokens) {
+  const LayerConfig& layer_config = weights.layer_config;
+  const size_t model_dim = layer_config.model_dim;
+  const size_t seq_len = forward.input.Rows();
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const size_t kHeads = layer_config.heads;
+  const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
+  const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
 
-  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
-             dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
-             kModelDim, kFFHiddenDim, num_tokens);
+  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
+             grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
+             kFFHiddenDim, num_tokens);
   GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
                backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
 
   MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
              backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
-             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
+             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
              num_tokens);
   RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
               backward.bf_pre_ffw_rms_out.data(),
               grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
-              kModelDim, num_tokens);
+              model_dim, num_tokens);
 
-  AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
+  AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
 
   MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
                       backward.attention_out.data(),
-                      grad.attn_vec_einsum_w.data(),
-                      backward.att_out.data(),
-                      kHeads, kModelDim, kQKVDim, num_tokens);
+                      grad.attn_vec_einsum_w.data(), backward.att_out.data(),
+                      kHeads, model_dim, qkv_dim, num_tokens);
 
   MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
                     backward.att_out.data(), backward.qkv.data(),
-                    backward.att.data(), num_tokens, kHeads, kQKVDim,
-                    kSeqLen);
+                    backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
 
-  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
-                    num_tokens, kHeads, kSeqLen);
+  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
+                    seq_len);
 
   MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
-                     backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
+                     backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
-    MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
+    MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
   }
 
   for (int pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
     for (size_t h = 0; h <= kHeads; ++h) {
-      Rope(qkv + h * kQKVDim, kQKVDim, -pos);
+      Rope(qkv + h * qkv_dim, qkv_dim, -pos);
     }
   }
 
   MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
              backward.qkv.data(), grad.qkv_einsum_w.data(),
-             backward.pre_att_rms_out.data(),
-             (kHeads + 2) * kQKVDim, kModelDim, num_tokens);
+             backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
+             num_tokens);
   RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
               backward.pre_att_rms_out.data(),
-              grad.pre_attention_norm_scale.data(),
-              backward.input.data(), kModelDim, num_tokens);
+              grad.pre_attention_norm_scale.data(), backward.input.data(),
+              model_dim, num_tokens);
 
   AddFromT(backward.attention_out.data(), backward.input.data(),
-           num_tokens * kModelDim);
+           num_tokens * model_dim);
 }
 
 template <typename T>
@@ -296,56 +294,54 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
   }
 }
 
-template <typename T, typename TConfig>
+template <typename T>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
-                                  const CompressedWeights<TConfig>& weights,
-                                  const ForwardPass<T, TConfig>& forward,
-                                  CompressedWeights<TConfig>& grad,
-                                  ForwardPass<T, TConfig>& backward) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kLayers = TConfig::kLayers;
+                                  const ModelWeightsPtrs<T>& weights,
+                                  const ForwardPass<T>& forward,
+                                  ModelWeightsPtrs<T>& grad,
+                                  ForwardPass<T>& backward) {
+  const ModelConfig& config = weights.weights_config;
+  const size_t model_dim = config.model_dim;
+  const size_t vocab_size = config.vocab_size;
+  const size_t layers = config.layer_configs.size();
   const std::vector<int> tokens = prompt.tokens;
   const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
 
   CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
-                       kVocabSize);
+                       vocab_size);
 
-  SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
-              kVocabSize, num_tokens);
+  SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
+              num_tokens);
 
-  if constexpr (TConfig::kFinalCap > 0.0f) {
+  if (config.final_cap > 0.0f) {
     for (size_t i = 0; i < num_tokens; ++i) {
-      SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
-                  backward.logits.data() + i * kVocabSize, kVocabSize);
+      SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
+                  backward.logits.data() + i * vocab_size, vocab_size);
     }
   }
 
-  MatMulVJPT(weights.embedder_input_embedding.data(),
-             forward.final_norm_output.data(),
-             backward.logits.data(),
-             grad.embedder_input_embedding.data(),
-             backward.final_norm_output.data(),
-             kVocabSize, kModelDim, num_tokens);
+  MatMulVJPT(
+      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
+      backward.logits.data(), grad.embedder_input_embedding.data(),
+      backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
   RMSNormVJPT(weights.final_norm_scale.data(),
               forward.final_layer_output.data(),
-              backward.final_norm_output.data(),
-              grad.final_norm_scale.data(),
-              backward.final_layer_output.data(), kModelDim, num_tokens);
+              backward.final_norm_output.data(), grad.final_norm_scale.data(),
+              backward.final_layer_output.data(), model_dim, num_tokens);
 
-  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
-    T* next_layer_grad = layer + 1 < kLayers
-                             ? backward.layers[layer + 1].input.data()
-                             : backward.final_layer_output.data();
+  for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
+    T* next_layer_grad = layer + 1 < layers
                             ? backward.layers[layer + 1].input.data()
                             : backward.final_layer_output.data();
     LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
             *grad.GetLayer(layer), backward.layers[layer], num_tokens);
   }
 
-  const T kEmbScaling = EmbeddingScaling(kModelDim);
-  InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
-                     tokens, kEmbScaling, backward.layers[0].input.data(),
-                     grad.embedder_input_embedding.data(), kModelDim);
+  const T kEmbScaling = EmbeddingScaling(model_dim);
+  InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
+                     kEmbScaling, backward.layers[0].input.data(),
+                     grad.embedder_input_embedding.data(), model_dim);
 }
 
 }  // namespace gcpp
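[Note: the tests below verify these scalar gradients via complex-step
differentiation (the Complexify/TestGradient helpers): evaluating f at
x + i*h gives f'(x) ~= Im(f(x + i*h)) / h with no subtractive cancellation.
A self-contained sketch of the idea, using our own toy function rather than
the gemma.cpp helpers:]

#include <complex>
#include <cstdio>

// f(x) = x^3, templated so it accepts real or complex arguments, mirroring
// how the tests run the scalar forward pass on complexified weights.
template <typename T>
T F(T x) { return x * x * x; }

int main() {
  const double x = 0.7, h = 1e-20;  // Tiny h is safe: no cancellation.
  const std::complex<double> fx = F(std::complex<double>(x, h));
  const double grad = fx.imag() / h;  // ~ 3 * x^2
  std::printf("f'(%g) = %g (exact %g)\n", x, grad, 3 * x * x);
}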
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index 262a121..b5e39db 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -19,7 +19,6 @@
 #include <stddef.h>
 #include <string.h>  // memcpy
 
-#include <array>
 #include <complex>
 #include <random>
 #include <vector>
@@ -384,44 +383,49 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-template <typename T>
-struct TestConfig : ConfigBaseGemmaV2 {
-  using Weight = T;
-  static constexpr int kSeqLen = 18;
-  static constexpr int kVocabSize = 12;
-  static constexpr int kModelDim = 32;
-  static constexpr int kHeads = 3;
-  static constexpr int kQKVDim = 12;
-  static constexpr int kFFHiddenDim = 48;
-  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
-      FixedLayerConfig<2>(LayerAttentionType::kGemma);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kNumTensorScales = 4 * kLayers;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-
-  static constexpr int kKVHeads = 1;
-  static constexpr int kGemmaLayers = kLayers;
-};
+static ModelConfig TestConfig() {
+  ModelConfig config;
+  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
+                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
+  config.model_dim = 32;
+  config.vocab_size = 12;
+  config.seq_len = 18;
+  LayerConfig layer_config = {
+      .model_dim = config.model_dim,
+      .ff_hidden_dim = 48,
+      .heads = 3,
+      .kv_heads = 1,
+      .qkv_dim = 12,
+  };
+  config.layer_configs = {2, layer_config};
+  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.query_scale = QueryScaleType::SqrtKeySize;
+  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
+  // This is required for optimize_test to pass.
+  config.final_cap = 30.0f;
+  return config;
+}
 
 TEST(BackPropTest, LayerVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  const size_t kOutputSize = TestConfig<T>::kSeqLen * TestConfig<T>::kModelDim;
-  CompressedLayer<TestConfig<T>> weights;
-  CompressedLayer<TestConfig<T>> grad;
-  ForwardLayer<T, TestConfig<T>> forward;
-  ForwardLayer<T, TestConfig<T>> backward = {};
-  CompressedLayer<TestConfig<TC>> c_weights;
-  ForwardLayer<TC, TestConfig<TC>> c_forward;
-  std::array<T, kOutputSize> y;
+  ModelConfig config = TestConfig();
+  const size_t kOutputSize = config.seq_len * config.model_dim;
+  LayerWeightsPtrs<T> weights(config.layer_configs[0]);
+  LayerWeightsPtrs<T> grad(config.layer_configs[0]);
+  ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
+  ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
+  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
+  ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
+  MatStorageT<T> y("y", kOutputSize, 1);
   MatStorageT<T> dy("dy", kOutputSize, 1);
-  std::array<TC, kOutputSize> c_y;
+  MatStorageT<TC> c_y("c_y", kOutputSize, 1);
   const size_t num_tokens = 3;
-  weights.Allocate();
-  grad.Allocate();
-  c_weights.Allocate();
+  std::vector<MatStorage> layer_storage;
+  weights.Allocate(layer_storage);
+  grad.Allocate(layer_storage);
+  c_weights.Allocate(layer_storage);
   backward.input.ZeroInit();
 
   for (size_t iter = 0; iter < 10; ++iter) {
@@ -432,7 +436,7 @@ TEST(BackPropTest, LayerVJP) {
     Complexify(forward.input, c_forward.input);
     auto func = [&]() {
       ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
-      return DotT(dy.data(), c_y.data(), num_tokens * TestConfig<T>::kModelDim);
+      return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
     };
     grad.ZeroInit(/*layer_idx=*/0);
     ApplyLayer(weights, forward, num_tokens, y.data());
@@ -447,12 +451,13 @@ TEST(BackPropTest, EndToEnd) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  WeightsWrapper<T, TestConfig<T>> weights;
-  WeightsWrapper<T, TestConfig<T>> grad;
-  ForwardPass<T, TestConfig<T>> forward;
-  ForwardPass<T, TestConfig<T>> backward;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<T> weights(config);
+  WeightsWrapper<T> grad(config);
+  ForwardPass<T> forward(config);
+  ForwardPass<T> backward(config);
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
 
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
@@ -474,9 +479,9 @@ TEST(BackPropTest, EndToEnd) {
   }
 }
 
-template <typename T, typename TConfig>
-void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
-                       CompressedLayer<TConfig>& out) {
+template <typename T>
+void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
+                       LayerWeightsPtrs<T>& out) {
   MulByConstAndAddT(c, x.pre_attention_norm_scale,
                     out.pre_attention_norm_scale);
   MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
@@ -486,23 +491,23 @@ void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
   MulByConstAndAddT(c, x.linear_w, out.linear_w);
 }
 
-template <typename T, typename TConfig>
-void MulByConstAndAddT(T c, const CompressedWeights<TConfig>& x,
-                       CompressedWeights<TConfig>& out) {
-  static constexpr size_t kLayers = TConfig::kLayers;
+template <typename T>
+void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
+                       ModelWeightsPtrs<T>& out) {
+  const size_t layers = x.c_layers.size();
   MulByConstAndAddT(c, x.embedder_input_embedding,
                     out.embedder_input_embedding);
   MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
-  for (size_t i = 0; i < kLayers; ++i) {
+  for (size_t i = 0; i < layers; ++i) {
     MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
   }
 }
 
 // Evaluates forward pass on a batch.
-template <typename T, typename TConfig>
+template <typename T>
 T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
-                              const WeightsWrapper<T, TConfig>& weights,
-                              ForwardPass<T, TConfig>& forward) {
+                              const WeightsWrapper<T>& weights,
+                              ForwardPass<T>& forward) {
   T loss = 0.0;
   for (const Prompt& prompt : batch) {
     loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
@@ -514,12 +519,11 @@ T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
 // Evaluates forward pass on a batch by applying gradient with the given
 // learning rate. Does not update weights, but uses the given tmp weights
 // instead.
-template <typename T, typename TConfig>
+template <typename T>
 T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
-                              const WeightsWrapper<T, TConfig>& weights,
-                              const WeightsWrapper<T, TConfig>& grad,
-                              WeightsWrapper<T, TConfig>& tmp,
-                              ForwardPass<T, TConfig>& forward) {
+                              const WeightsWrapper<T>& weights,
+                              const WeightsWrapper<T>& grad,
+                              WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
   tmp.CopyFrom(weights);
   const T scale = -learning_rate / batch.size();
   MulByConstAndAddT(scale, grad.get(), tmp.get());
@@ -529,11 +533,9 @@ T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
 // Uses line search in the negative gradient direction to update weights. We do
 // this so that we can test that each step during the gradient descent can
 // decrease the objective function value.
-template <typename T, typename TConfig>
-T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
-                    WeightsWrapper<T, TConfig>& weights,
-                    WeightsWrapper<T, TConfig>& tmp,
-                    ForwardPass<T, TConfig>& forward,
+template <typename T>
+T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
+                    WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
                     const std::vector<Prompt>& batch, T loss,
                     T initial_learning_rate) {
   T lr0 = initial_learning_rate;
@@ -568,13 +570,14 @@ TEST(BackProptest, Convergence) {
   std::mt19937 gen(42);
   using T = float;
   using TC = std::complex<double>;
-  WeightsWrapper<T, TestConfig<T>> weights;
-  WeightsWrapper<T, TestConfig<T>> grad;
-  WeightsWrapper<T, TestConfig<T>> tmp;
-  ForwardPass<T, TestConfig<T>> forward;
-  ForwardPass<T, TestConfig<T>> backward;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<T> weights(config);
+  WeightsWrapper<T> grad(config);
+  WeightsWrapper<T> tmp(config);
+  ForwardPass<T> forward(config);
+  ForwardPass<T> backward(config);
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
   constexpr size_t kBatchSize = 5;
   ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
   T learning_rate = 0.01;
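[Note: FindOptimalUpdate above line-searches the step size so each descent
step provably decreases the loss. A standalone sketch of that strategy with
our own simplified loss and helper names, not the test's actual code:]

#include <cstdio>
#include <functional>

// Keep doubling the learning rate while the trial loss improves; if even the
// initial rate overshoots, halve it until some step improves on loss0.
// f maps a learning rate to the loss after a trial step from fixed weights.
double FindStepSize(const std::function<double(double)>& f, double loss0,
                    double lr0) {
  double best_lr = 0.0, best_loss = loss0;
  for (double lr = lr0; f(lr) < best_loss; lr *= 2.0) {
    best_loss = f(lr);
    best_lr = lr;
  }
  if (best_lr == 0.0) {
    for (double lr = lr0 * 0.5; lr > 1e-12; lr *= 0.5) {
      if (f(lr) < loss0) return lr;
    }
  }
  return best_lr;
}

int main() {
  // Toy quadratic loss in 1D: step from x=3 along -grad with rate lr.
  const double x = 3.0, grad = 2.0 * x;
  auto trial = [&](double lr) { double nx = x - lr * grad; return nx * nx; };
  std::printf("chosen lr = %g\n", FindStepSize(trial, x * x, 0.01));
}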
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
index 01c5e73..2b82c12 100644
--- a/backprop/backward_test.cc
+++ b/backprop/backward_test.cc
@@ -19,7 +19,6 @@
 
 #include <stddef.h>
 
-#include <array>
 #include <complex>
 #include <cmath>  // std::abs
 #include <random>
@@ -34,7 +33,6 @@
 #include "backprop/test_util.h"
 #include "gemma/activations.h"
 #include "gemma/configs.h"
-#include "gemma/weights.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -50,6 +48,7 @@
 #include "backprop/forward-inl.h"
 #include "compression/compress.h"
 #include "ops/ops-inl.h"
+#include "util/allocator.h"
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -85,8 +84,8 @@ void TestMatMulVJP() {
     };
 
     grad.ZeroInit();
-    MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
-                            grad.data(), dx.data(), pool);
+    MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
+              grad.data(), dx.data(), pool);
 
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -130,9 +129,8 @@ void TestMultiHeadMatMulVJP() {
     };
 
     grad.ZeroInit();
-    MultiHeadMatMulVJP<kHeads, kCols, kRows>(
-        weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
-        pool);
+    MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
+                       kRows, kTokens, grad.data(), dx.data(), pool);
 
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -186,63 +184,63 @@ void TestRMSNormVJP() {
   }
 }
 
-template <typename T>
-struct TestConfig : ConfigBaseGemmaV2 {
-  using Weight = T;
-  static constexpr int kSeqLen = 24;
-  static constexpr int kVocabSize = 16;
-  static constexpr int kModelDim = 32;
-  static constexpr int kHeads = 3;
-  static constexpr int kQKVDim = 16;
-  static constexpr int kFFHiddenDim = 64;
-  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
-      FixedLayerConfig<2>(LayerAttentionType::kGemma);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kNumTensorScales = 4 * kLayers;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-
-  static constexpr int kKVHeads = 1;
-  static constexpr int kGemmaLayers = kLayers;
-};
+static ModelConfig TestConfig() {
+  ModelConfig config;
+  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
+                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
+  config.model_dim = 32;
+  config.vocab_size = 16;
+  config.seq_len = 24;
+  LayerConfig layer_config = {
+      .model_dim = config.model_dim,
+      .ff_hidden_dim = 64,
+      .heads = 3,
+      .kv_heads = 1,
+      .qkv_dim = 16,
+  };
+  config.layer_configs = {2, layer_config};
+  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.query_scale = QueryScaleType::SqrtKeySize;
+  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
+  // This is required for optimize_test to pass.
+  config.att_cap = 50.0f;
+  config.final_cap = 30.0f;
+  return config;
+}
 
 void TestEndToEnd() {
   std::mt19937 gen(42);
   hwy::ThreadPool pool(0);
-  using WeightsF = CompressedWeights<TestConfig<float>>;
-  using LayerF = CompressedLayer<TestConfig<float>>;
-  WeightsWrapper<float, TestConfig<float>> weights;
-  WeightsWrapper<float, TestConfig<float>> grad;
-  ActivationsWrapper<float, TestConfig<float>> forward0;
-  ActivationsWrapper<float, TestConfig<float>> forward1;
-  ActivationsWrapper<float, TestConfig<float>> backward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<float> weights(config);
+  WeightsWrapper<float> grad(config);
+  ForwardPass<float> forward0(config);
+  ForwardPass<float> forward1(config);
+  ForwardPass<float> backward(config);
   using TC = std::complex<double>;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
 
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
 
-  RowVectorBatch<float> inv_timescale =
-      Activations::CreateInvTimescale<TestConfig<float>>();
+  RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
+      config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
 
   for (const Prompt& prompt : batch) {
     ReverseSequenceSampler::LogPrompt(prompt);
     RandInit(weights.get(), 1.0f, gen);
-    float loss0 = CrossEntropyLossForwardPass<float, TestConfig<float>>(
-        prompt, weights.get(), forward0.get());
+    float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
 
-    float loss1 =
-        CrossEntropyLossForwardPass<TestConfig<float>, WeightsF, LayerF>(
-            prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
-            inv_timescale, pool);
+    float loss1 = CrossEntropyLossForwardPass(
+        prompt.tokens, prompt.context_size, weights.get(), forward1,
+        inv_timescale, pool);
 
     EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
 
     grad.ZeroInit();
-    CrossEntropyLossBackwardPass<TestConfig<float>, WeightsF, LayerF>(
-        prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
-        inv_timescale, pool);
+    CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
+                                    backward, inv_timescale, pool);
 
     Complexify(weights.get(), c_weights.get());
 
     auto func = [&]() {
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index b6b1dc0..ca969c4 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -26,6 +26,7 @@
 #include "backprop/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
+#include "gemma/weights.h"
 #include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -93,29 +94,29 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
   return loss * scaling;
 }
 
-template <typename TConfig, typename LayerT>
-void ApplyForwardLayer(const LayerT& weights,
-                       ForwardLayer<float, TConfig>& activations,
-                       size_t num_tokens, float* HWY_RESTRICT output,
+template <typename T>
+void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
+                       ForwardLayer<T>& activations, size_t num_tokens,
+                       float* HWY_RESTRICT output,
                        const RowVectorBatch<float>& inv_timescale,
                        hwy::ThreadPool& pool) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static const float kQueryScale =
+  const LayerConfig& config = weights.layer_config;
+  const size_t model_dim = config.model_dim;
+  const size_t kSeqLen = activations.input.Rows();
+  const size_t kQKVDim = config.qkv_dim;
+  const size_t kHeads = config.heads;
+  static const float query_scale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
   HWY_ASSERT(num_tokens <= kSeqLen);
 
   ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
-               activations.input.data(), kModelDim, num_tokens,
+               activations.input.data(), model_dim, num_tokens,
                activations.pre_att_rms_out.data(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
-        weights.qkv_einsum_w, 0,
-        activations.pre_att_rms_out.data() + pos * kModelDim,
-        activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
+    MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
+           activations.pre_att_rms_out.data() + pos * model_dim,
+           activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
   }
 
   const size_t num_tasks = kHeads * num_tokens;
@@ -130,7 +131,7 @@ void ApplyForwardLayer(const LayerT& weights,
     float* HWY_RESTRICT q =
         activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
     Rope(q, kQKVDim, inv_timescale.Const(), pos);
-    MulByConst(kQueryScale, q, kQKVDim);
+    MulByConst(query_scale, q, kQKVDim);
   });
 
   pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
@@ -174,29 +175,29 @@ void ApplyForwardLayer(const LayerT& weights,
   activations.attention_out.ZeroInit();
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     for (size_t head = 0; head < kHeads; ++head) {
-      MatVec<kModelDim, kQKVDim>(
-          weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
+      MatVec(
+          weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
+          kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
-          activations.att_post1.data() + pos * kModelDim, pool);
-      AddFrom(activations.att_post1.data() + pos * kModelDim,
-              activations.attention_out.data() + pos * kModelDim, kModelDim);
+          activations.att_post1.data() + pos * model_dim, pool);
+      AddFrom(activations.att_post1.data() + pos * model_dim,
+              activations.attention_out.data() + pos * model_dim, model_dim);
     }
   }
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(activations.input.data() + pos * kModelDim,
-            activations.attention_out.data() + pos * kModelDim, kModelDim);
+    AddFrom(activations.input.data() + pos * model_dim,
activations.attention_out.data() + pos * model_dim, model_dim); } ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), - activations.attention_out.data(), kModelDim, num_tokens, + activations.attention_out.data(), model_dim, num_tokens, activations.bf_pre_ffw_rms_out.data(), pool); - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + const size_t kFFHiddenDim = config.ff_hidden_dim; for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec( - weights.gating_einsum_w, 0, - activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, - activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); + MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, + activations.bf_pre_ffw_rms_out.data() + pos * model_dim, + activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t hidden_offset = pos * kFFHiddenDim * 2; @@ -215,77 +216,76 @@ void ApplyForwardLayer(const LayerT& weights, } } for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec( - weights.linear_w, 0, - activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, - output + pos * kModelDim, pool); + MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim, + activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, + output + pos * model_dim, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.attention_out.data() + pos * kModelDim, - output + pos * kModelDim, kModelDim); + AddFrom(activations.attention_out.data() + pos * model_dim, + output + pos * model_dim, model_dim); } } -template +template float CrossEntropyLossForwardPass(const std::vector& prompt, - size_t context_size, const WeightsT& weights, - ForwardPass& forward, + size_t context_size, + const ModelWeightsPtrs& weights, + ForwardPass& forward, const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { - static constexpr size_t kVocabSize = TConfig::kVocabSize; - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kLayers = TConfig::kLayers; - const float kEmbScaling = EmbeddingScaling(); - static_assert(!TConfig::kAbsolutePE); - static_assert(TConfig::kPostNorm == PostNormType::None); - static_assert(TConfig::kKVHeads == 1); + const ModelConfig& config = weights.weights_config; + const size_t vocab_size = config.vocab_size; + const size_t model_dim = config.model_dim; + const size_t layers = config.layer_configs.size(); + const float emb_scaling = EmbeddingScaling(model_dim); + HWY_ASSERT(!config.absolute_pe); + HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None); + HWY_ASSERT(config.layer_configs[0].kv_heads == 1); HWY_DASSERT(context_size > 0); HWY_DASSERT(context_size < prompt.size()); const size_t num_tokens = prompt.size() - 1; - InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, - forward.layers[0].input.data(), kModelDim, kVocabSize); + InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling, + forward.layers[0].input.data(), model_dim, vocab_size); - for (size_t layer = 0; layer < kLayers; ++layer) { - auto type = TConfig::kLayerConfig[layer]; + for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { + auto type = config.layer_configs[layer].type; // TODO(szabadka) Implement Griffin layer. HWY_ASSERT(type == LayerAttentionType::kGemma); - float* HWY_RESTRICT output = layer + 1 < kLayers ? 
- forward.layers[layer + 1].input.data() : - forward.final_layer_output.data(); - ApplyForwardLayer(*weights.GetLayer(layer), - forward.layers[layer], num_tokens, - output, inv_timescale, pool); + float* HWY_RESTRICT output = layer + 1 < layers + ? forward.layers[layer + 1].input.data() + : forward.final_layer_output.data(); + ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer], + num_tokens, output, inv_timescale, pool); } ApplyRMSNorm(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - kModelDim, num_tokens, forward.final_norm_output.data(), pool); + forward.final_layer_output.data(), model_dim, num_tokens, + forward.final_norm_output.data(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec( - weights.embedder_input_embedding, 0, - forward.final_norm_output.data() + pos * kModelDim, - forward.logits.data() + pos * kVocabSize, pool); + MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim, + forward.final_norm_output.data() + pos * model_dim, + forward.logits.data() + pos * vocab_size, pool); } - if constexpr (TConfig::kFinalCap > 0.0f) { + if (config.final_cap > 0.0f) { for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(TConfig::kFinalCap, - forward.logits.data() + pos * kVocabSize, kVocabSize); + LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size, + vocab_size); } } hwy::CopyBytes(forward.logits.data(), forward.probs.data(), - num_tokens * kVocabSize * sizeof(forward.logits.At(0))); + num_tokens * vocab_size * sizeof(forward.logits.At(0))); for (size_t pos = 0; pos < num_tokens; ++pos) { - Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); + Softmax(forward.probs.data() + pos * vocab_size, vocab_size); } return CrossEntropyLoss(forward.probs.data(), prompt, context_size, - kVocabSize, pool); + vocab_size, pool); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/backprop/forward.cc b/backprop/forward.cc index 5b2cf1a..0c6cc5c 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -17,8 +17,9 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -36,38 +37,13 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template -float CrossEntropyLossForwardPass(const Prompt& prompt, - const ByteStorageT& weights_u8, - ByteStorageT& forward_u8, - RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - const auto& weights = - *reinterpret_cast*>(weights_u8.get()); - auto& forward = - *reinterpret_cast*>(forward_u8.get()); - return CrossEntropyLossForwardPass, - CompressedLayer>( - prompt.tokens, prompt.context_size, weights, forward, inv_timescale, - pool); -} - -float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPassT(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { - // TODO(janwas): use CallFunctorForModel - switch (model) { - case Model::GEMMA_2B: - return CrossEntropyLossForwardPass>( - prompt, weights, forward, inv_timescale, pool); - case Model::GEMMA_TINY: - return CrossEntropyLossForwardPass>( - prompt, weights, forward, inv_timescale, pool); - default: - HWY_ABORT("Model type %d unknown.", 
static_cast(model)); - } + return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size, + weights, forward, inv_timescale, pool); } } // namespace HWY_NAMESPACE @@ -79,13 +55,13 @@ namespace gcpp { HWY_EXPORT(CrossEntropyLossForwardPassT); -float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( - model, prompt, weights, forward, inv_timescale, pool); + prompt, weights, forward, inv_timescale, pool); } } // namespace gcpp diff --git a/backprop/forward.h b/backprop/forward.h index 92ca371..3b42298 100644 --- a/backprop/forward.h +++ b/backprop/forward.h @@ -16,16 +16,17 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ +#include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" -#include "gemma/common.h" +#include "gemma/weights.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool); diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 064112b..617d0c3 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -127,108 +127,107 @@ void InputEmbedding(const T* w, const std::vector& tokens, T scaling, } } -template -void MaskedAttention(const T* qkv, T* output, size_t num_tokens, - size_t kHeads, size_t kQKVDim, size_t kSeqLen) { +template +void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads, + size_t qkv_dim, size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - const size_t qoffset = pos * (kHeads + 2) * kQKVDim; - const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen; - const T* q = qkv + qoffset + head * kQKVDim; + for (size_t head = 0; head < heads; ++head) { + const size_t qoffset = pos * (heads + 2) * qkv_dim; + const size_t aoffset = pos * heads * seq_len + head * seq_len; + const T* q = qkv + qoffset + head * qkv_dim; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; - output[aoffset + pos2] = DotT(q, k, kQKVDim); + const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim; + output[aoffset + pos2] = DotT(q, k, qkv_dim); } } } } -template -void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) { +template +void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + for (size_t head = 0; head < heads; ++head) { + size_t offset = pos * heads * seq_len + head * seq_len; Softmax(x + offset, pos + 1); - memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); } } } -template +template void MixByAttention(const T* qkv, const T* attention, T* output, - size_t num_tokens, size_t kHeads, size_t kQKVDim, - 
size_t kSeqLen) { + size_t num_tokens, size_t heads, size_t qkv_dim, + size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen]; - T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim]; - memset(out, 0, kQKVDim * sizeof(out[0])); + for (size_t head = 0; head < heads; ++head) { + const T* att = &attention[pos * heads * seq_len + head * seq_len]; + T* out = &output[head * qkv_dim + pos * heads * qkv_dim]; + memset(out, 0, qkv_dim * sizeof(out[0])); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim; const T* v = &qkv[v_offset]; - MulByConstAndAddT(att[pos2], v, out, kQKVDim); + MulByConstAndAddT(att[pos2], v, out, qkv_dim); } } } } -template -void ApplyLayer(const CompressedLayer& weights, - ForwardLayer& activations, size_t num_tokens, - T* output) { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kSeqLen = TConfig::kSeqLen; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim)); +template +void ApplyLayer(const LayerWeightsPtrs& weights, + ForwardLayer& activations, size_t num_tokens, T* output) { + const LayerConfig& layer_config = weights.layer_config; + const size_t model_dim = layer_config.model_dim; + const size_t seq_len = activations.input.Rows(); + const size_t qkv_dim = layer_config.qkv_dim; + const size_t heads = layer_config.heads; + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim)); RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), - activations.pre_att_rms_out.data(), kModelDim, num_tokens); + activations.pre_att_rms_out.data(), model_dim, num_tokens); MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), - activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim, - num_tokens); + activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; - for (size_t h = 0; h <= kHeads; ++h) { - Rope(qkv + h * kQKVDim, kQKVDim, pos); + T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + for (size_t h = 0; h <= heads; ++h) { + Rope(qkv + h * qkv_dim, qkv_dim, pos); } } for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; - MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); + T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + MulByConstT(query_scale, qkv, heads * qkv_dim); } - MaskedAttention(activations.qkv.data(), activations.att.data(), - num_tokens, kHeads, kQKVDim, kSeqLen); + MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens, + heads, qkv_dim, seq_len); - MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen); + MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len); MixByAttention(activations.qkv.data(), activations.att.data(), - activations.att_out.data(), num_tokens, kHeads, kQKVDim, - kSeqLen); + activations.att_out.data(), num_tokens, heads, qkv_dim, + seq_len); MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), - 
activations.attention_out.data(), kHeads, kModelDim, kQKVDim, + activations.attention_out.data(), heads, model_dim, qkv_dim, num_tokens); AddFromT(activations.input.data(), activations.attention_out.data(), - num_tokens * kModelDim); + num_tokens * model_dim); RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), - activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens); + activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens); MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), - activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim, + activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim, num_tokens); GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), - kFFHiddenDim, num_tokens); + ff_hidden_dim, num_tokens); - MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), - output, kModelDim, kFFHiddenDim, num_tokens); + MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output, + model_dim, ff_hidden_dim, num_tokens); - AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim); + AddFromT(activations.attention_out.data(), output, num_tokens * model_dim); } template @@ -247,48 +246,47 @@ T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { return loss * scaling; } -template +template T CrossEntropyLossForwardPass(const Prompt& prompt, - const CompressedWeights& weights, - ForwardPass& forward) { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kVocabSize = TConfig::kVocabSize; - static constexpr size_t kLayers = TConfig::kLayers; + const ModelWeightsPtrs& weights, + ForwardPass& forward) { + const ModelConfig& config = weights.weights_config; + const size_t model_dim = config.model_dim; + const size_t vocab_size = config.vocab_size; + const size_t layers = config.layer_configs.size(); const std::vector tokens = prompt.tokens; const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - const T kEmbScaling = EmbeddingScaling(kModelDim); - InputEmbedding(weights.embedder_input_embedding.data(), tokens, - kEmbScaling, forward.layers[0].input.data(), kModelDim); + const T kEmbScaling = EmbeddingScaling(model_dim); + InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling, + forward.layers[0].input.data(), model_dim); - for (size_t layer = 0; layer < kLayers; ++layer) { - T* output = layer + 1 < kLayers ? - forward.layers[layer + 1].input.data() : - forward.final_layer_output.data(); + for (size_t layer = 0; layer < layers; ++layer) { + T* output = layer + 1 < layers ? 
forward.layers[layer + 1].input.data() + : forward.final_layer_output.data(); ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, output); } - RMSNormT(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - forward.final_norm_output.data(), kModelDim, num_tokens); + RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(), + forward.final_norm_output.data(), model_dim, num_tokens); MatMulT(weights.embedder_input_embedding.data(), - forward.final_norm_output.data(), - forward.logits.data(), kVocabSize, kModelDim, num_tokens); + forward.final_norm_output.data(), forward.logits.data(), vocab_size, + model_dim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - if constexpr (TConfig::kFinalCap > 0.0f) { - Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, - kVocabSize); + if (config.final_cap > 0.0f) { + Softcap(config.final_cap, forward.logits.data() + pos * vocab_size, + vocab_size); } } memcpy(forward.probs.data(), forward.logits.data(), - num_tokens * kVocabSize * sizeof(forward.logits.At(0))); - Softmax(forward.probs.data(), kVocabSize, num_tokens); + num_tokens * vocab_size * sizeof(forward.logits.At(0))); + Softmax(forward.probs.data(), vocab_size, num_tokens); - return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); + return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size); } } // namespace gcpp diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 26698c6..b47a48d 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -26,8 +27,10 @@ #include "backprop/optimizer.h" #include "backprop/prompt.h" #include "backprop/sampler.h" +#include "compression/shared.h" #include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/weights.h" #include "util/threading.h" @@ -45,20 +48,18 @@ TEST(OptimizeTest, GradientDescent) { .training = ModelTraining::GEMMA_IT, .weight = Type::kF32, }; - ByteStorageT grad = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT grad_m = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT grad_v = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT forward = - CallForModelAndWeight(info.model, info.weight); - ByteStorageT backward = - CallForModelAndWeight(info.model, info.weight); - KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); + ModelConfig config = ConfigFromModel(info.model); + ModelWeightsStorage grad, grad_m, grad_v; + grad.Allocate(info.model, info.weight, pool); + grad_m.Allocate(info.model, info.weight, pool); + grad_v.Allocate(info.model, info.weight, pool); + grad_m.ZeroInit(); + grad_v.ZeroInit(); + ForwardPass forward(config), backward(config); + KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); - RowVectorBatch inv_timescale = - Activations::CreateInvTimescale>(); + RowVectorBatch inv_timescale = Activations::CreateInvTimescale( + config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk); Gemma gemma(GemmaTokenizer(), info, pools); @@ -92,14 +93,11 @@ TEST(OptimizeTest, GradientDescent) { reply.begin() + context.size()); }; - RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen); - CallForModelAndWeight(info.model, info.weight, - grad_m, pool); - CallForModelAndWeight(info.model, info.weight, - grad_v, pool); + gemma.MutableWeights().RandInit(gen); + 
gemma.MutableWeights().AllocAndCopyWithTranspose(pool); printf("Initial weights:\n"); - LogWeightStats(info.model, info.weight, gemma.Weights()); + gemma.MutableWeights().LogWeightStats(); constexpr size_t kBatchSize = 8; const float alpha = 0.001f; @@ -113,29 +111,29 @@ TEST(OptimizeTest, GradientDescent) { size_t num_ok; for (; steps < 1000000; ++steps) { std::mt19937 sgen(42); - CallForModelAndWeight(info.model, info.weight, - grad, pool); + grad.ZeroInit(); float total_loss = 0.0f; num_ok = 0; for (size_t i = 0; i < kBatchSize; ++i) { Prompt prompt = training_task.Sample(sgen); total_loss += CrossEntropyLossForwardPass( - info.model, prompt, gemma.Weights(), forward, inv_timescale, pool); - CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward, - grad, backward, inv_timescale, pool); - CallForModelAndWeight( - info.model, info.weight, gemma.MutableWeights(), pool); + prompt, *gemma.Weights().GetWeightsOfType(), forward, + inv_timescale, pool); + CrossEntropyLossBackwardPass( + prompt, *gemma.Weights().GetWeightsOfType(), forward, + *grad.GetWeightsOfType(), backward, inv_timescale, pool); + gemma.MutableWeights().CopyWithTranspose(pool); num_ok += verify(prompt) ? 1 : 0; } total_loss /= kBatchSize; - AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon, - steps + 1, gemma.Weights(), grad_m, grad_v, pool); + AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1, + gemma.Weights(), grad_m, grad_v, pool); printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", steps, total_loss, num_ok, kBatchSize); if (steps % 100 == 0) { printf("Batch gradient:\n"); - LogWeightStats(info.model, info.weight, grad); + grad.LogWeightStats(); } if (total_loss < 0.5f) { break; @@ -143,7 +141,7 @@ TEST(OptimizeTest, GradientDescent) { } printf("Num steps: %zu\n", steps); printf("Final weights:\n"); - LogWeightStats(info.model, info.weight, gemma.Weights()); + gemma.MutableWeights().LogWeightStats(); EXPECT_LT(steps, 300); EXPECT_EQ(num_ok, kBatchSize); } diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 800f2fa..9187bf7 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -16,7 +16,6 @@ #include "backprop/optimizer.h" #include -#include #include "compression/compress.h" #include "gemma/common.h" @@ -30,37 +29,6 @@ namespace gcpp { namespace { -class WeightInitializer { - public: - WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} - - void operator()(const char* name, hwy::Span tensors) { - float* data = tensors[0]->data(); - for (size_t i = 0; i < tensors[0]->NumElements(); ++i) { - data[i] = dist_(gen_); - } - tensors[0]->set_scale(1.0f); - } - - private: - std::normal_distribution dist_; - std::mt19937& gen_; -}; - -template -struct RandInitWeightsT { - void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool, - std::mt19937& gen) const { - auto& weights = - *reinterpret_cast*>(weights_u8.get()); - // TODO(szabadka) Use the same weight initialization method as in the python - // version. 
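
The WeightInitializer being deleted here had one job: fill each tensor with N(0, 1) samples via ForEachTensor. That responsibility moves into ModelWeightsStorage, so the caller-side end state of the test above looks roughly like this (a sketch; `info`, `pool`, and `gen` are the surrounding test fixtures):

    ModelWeightsStorage grad;
    grad.Allocate(info.model, info.weight, pool);  // was a CallForModelAndWeight functor
    grad.ZeroInit();                               // was a CallForModelAndWeight functor
    gemma.MutableWeights().RandInit(gen);          // was RandInitWeights(model, type, ...)
    gemma.MutableWeights().LogWeightStats();       // was LogWeightStats(model, type, ...)
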
- WeightInitializer init(gen); - CompressedWeights::ForEachTensor({&weights}, - ForEachType::kLoadNoToc, init); - } -}; - class AdamUpdater { public: explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon, @@ -97,42 +65,31 @@ class AdamUpdater { float epsilon_; }; -template -struct AdamUpdateT { - void operator()(const ByteStorageT& grad_u8, float alpha, float beta1, - float beta2, float epsilon, size_t t, - const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, - const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { - using TWeights = CompressedWeights; - auto& grad = *reinterpret_cast(grad_u8.get()); - auto& weights = *reinterpret_cast(weights_u8.get()); - auto& grad_m = *reinterpret_cast(grad_m_u8.get()); - auto& grad_v = *reinterpret_cast(grad_v_u8.get()); - AdamUpdater updater(alpha, beta1, beta2, epsilon, t); - TWeights::ForEachTensor( - {&grad, &weights, &grad_m, &grad_v}, ForEachType::kLoadNoToc, - [&updater](const char* name, hwy::Span tensors) { - updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); - }); - } -}; +void AdamUpdate(ModelWeightsPtrs* grad, float alpha, float beta1, + float beta2, float epsilon, size_t t, + ModelWeightsPtrs* weights, + ModelWeightsPtrs* grad_m, + ModelWeightsPtrs* grad_v, hwy::ThreadPool& pool) { + AdamUpdater updater(alpha, beta1, beta2, epsilon, t); + ModelWeightsPtrs::ForEachTensor( + {grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc, + [&updater](const char* name, hwy::Span tensors) { + updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); + }); +} } // namespace -void RandInitWeights(Model model_type, Type weight_type, - const ByteStorageT& weights, hwy::ThreadPool& pool, - std::mt19937& gen) { +void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha, + float beta1, float beta2, float epsilon, size_t t, + const ModelWeightsStorage& weights, + const ModelWeightsStorage& grad_m, + const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) { HWY_ASSERT(weight_type == Type::kF32); - CallForModel(model_type, weights, pool, gen); -} - -void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, - float alpha, float beta1, float beta2, float epsilon, size_t t, - const ByteStorageT& weights, const ByteStorageT& grad_m, - const ByteStorageT& grad_v, hwy::ThreadPool& pool) { - HWY_ASSERT(weight_type == Type::kF32); - CallForModel(model_type, grad, alpha, beta1, beta2, - epsilon, t, weights, grad_m, grad_v, pool); + AdamUpdate(grad.GetWeightsOfType(), alpha, beta1, beta2, epsilon, t, + weights.GetWeightsOfType(), + grad_m.GetWeightsOfType(), grad_v.GetWeightsOfType(), + pool); } } // namespace gcpp diff --git a/backprop/optimizer.h b/backprop/optimizer.h index b42f311..8b25c52 100644 --- a/backprop/optimizer.h +++ b/backprop/optimizer.h @@ -16,22 +16,17 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ -#include - #include "gemma/common.h" -#include "util/allocator.h" +#include "gemma/weights.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -void RandInitWeights(Model model_type, Type weight_type, - const ByteStorageT& weights, hwy::ThreadPool& pool, - std::mt19937& gen); - -void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, - float alpha, float beta1, float beta2, float epsilon, size_t t, - const ByteStorageT& weights, const ByteStorageT& grad_m, - const ByteStorageT& grad_v, hwy::ThreadPool& pool); +void AdamUpdate(Type weight_type, const 
ModelWeightsStorage& grad, float alpha, + float beta1, float beta2, float epsilon, size_t t, + const ModelWeightsStorage& weights, + const ModelWeightsStorage& grad_m, + const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/test_util.h b/backprop/test_util.h index bfa2cc5..86f99b1 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -21,11 +21,12 @@ #include #include #include +#include #include "gtest/gtest.h" #include "compression/compress.h" +#include "gemma/configs.h" #include "gemma/weights.h" -#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -39,8 +40,8 @@ void RandInit(MatPtrT& x, T stddev, std::mt19937& gen) { } // TODO: make a member of Layer. -template -void RandInit(CompressedLayer& w, T stddev, std::mt19937& gen) { +template +void RandInit(LayerWeightsPtrs& w, T stddev, std::mt19937& gen) { RandInit(w.pre_attention_norm_scale, stddev, gen); RandInit(w.attn_vec_einsum_w, stddev, gen); RandInit(w.qkv_einsum_w, stddev, gen); @@ -49,9 +50,9 @@ void RandInit(CompressedLayer& w, T stddev, std::mt19937& gen) { RandInit(w.linear_w, stddev, gen); } -template -void RandInit(CompressedWeights& w, T stddev, std::mt19937& gen) { - static constexpr size_t kLayers = TConfig::kLayers; +template +void RandInit(ModelWeightsPtrs& w, T stddev, std::mt19937& gen) { + const size_t kLayers = w.c_layers.size(); RandInit(w.embedder_input_embedding, stddev, gen); RandInit(w.final_norm_scale, stddev, gen); for (size_t i = 0; i < kLayers; ++i) { @@ -66,9 +67,8 @@ void Complexify(const MatPtrT& x, MatPtrT>& c_x) { } } -template -void Complexify(const CompressedLayer& w, - CompressedLayer& c_w) { +template +void Complexify(const LayerWeightsPtrs& w, LayerWeightsPtrs& c_w) { Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); @@ -77,10 +77,9 @@ void Complexify(const CompressedLayer& w, Complexify(w.linear_w, c_w.linear_w); } -template -void Complexify(const CompressedWeights& w, - CompressedWeights& c_w) { - static constexpr size_t kLayers = TConfig::kLayers; +template +void Complexify(const ModelWeightsPtrs& w, ModelWeightsPtrs& c_w) { + const size_t kLayers = w.c_layers.size(); Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); Complexify(w.final_norm_scale, c_w.final_norm_scale); for (size_t i = 0; i < kLayers; ++i) { @@ -88,26 +87,27 @@ void Complexify(const CompressedWeights& w, } } -// Owns weights and provides access to TConfig. -template +// Somewhat duplicates ModelWeightsStorage, but that has neither double nor +// complex types allowed and it would cause code bloat to add them there. 
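
The double and complex instantiations exist because the gradient tests use complex-step differentiation: for an analytic f, Im(f(x + ih)) / h approximates f'(x) with no cancellation error, which is why the TestGradient overloads below can pass a step as small as 1e-50. A self-contained sketch of the technique (illustrative, not code from this patch):

    #include <cmath>
    #include <complex>

    // Complex-step derivative: Im(sin(x + i*h)) / h ~= cos(x). Because no
    // subtraction is involved, h can be far below machine epsilon.
    double ComplexStepDerivative(double x) {
      const double h = 1e-50;
      const std::complex<double> fx = std::sin(std::complex<double>(x, h));
      return fx.imag() / h;
    }
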
+template class WeightsWrapper { public: - WeightsWrapper() - : pool_(0), - data_(AllocateCompressedWeights()(pool_)), - weights_(reinterpret_cast*>(data_.get())) {} + explicit WeightsWrapper(const ModelConfig& config) + : pool_(0), weights_(config, pool_) { + weights_.Allocate(data_, pool_); + } - const CompressedWeights& get() const { return *weights_; } - CompressedWeights& get() { return *weights_; } - void ZeroInit() { weights_->ZeroInit(); } - void CopyFrom(const WeightsWrapper& other) { - get().CopyFrom(other.get()); + const ModelWeightsPtrs& get() const { return weights_; } + ModelWeightsPtrs& get() { return weights_; } + void ZeroInit() { weights_.ZeroInit(); } + void CopyFrom(const WeightsWrapper& other) { + weights_.CopyFrom(other.weights_); } private: hwy::ThreadPool pool_; - ByteStorageT data_; - CompressedWeights* weights_; + std::vector data_; + ModelWeightsPtrs weights_; }; template @@ -173,9 +173,9 @@ void TestGradient(const MatPtrT& grad, MatPtrT>& x, TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); } -template -void TestGradient(const CompressedLayer& grad, - CompressedLayer& c_weights, FUNC func, T max_err) { +template +void TestGradient(const LayerWeightsPtrs& grad, + LayerWeightsPtrs& c_weights, FUNC func, T max_err) { TestGradient(grad.pre_attention_norm_scale, c_weights.pre_attention_norm_scale, func, max_err, max_err, __LINE__); @@ -191,15 +191,15 @@ void TestGradient(const CompressedLayer& grad, func, max_err, max_err, __LINE__); } -template -void TestGradient(const CompressedWeights& grad, - CompressedWeights& c_weights, FUNC func, T max_err) { +template +void TestGradient(const ModelWeightsPtrs& grad, + ModelWeightsPtrs& c_weights, FUNC func, T max_err) { TestGradient(grad.embedder_input_embedding, c_weights.embedder_input_embedding, func, 2 * max_err, max_err, __LINE__); TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err, max_err, __LINE__); - for (int i = 0; i < TConfig::kLayers; ++i) { + for (size_t i = 0; i < grad.c_layers.size(); ++i) { TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err); } } diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 24248a1..57f50f5 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -276,6 +275,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { if (!pfile->Read(requests[i].offset, requests[i].size, requests[i].data)) { + fprintf(stderr, "Failed to read blob %zu\n", i); err.test_and_set(); } }); diff --git a/compression/compress.h b/compression/compress.h index e0ea0d7..adb35a1 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -102,8 +102,8 @@ class CompressedArray { class MatPtr { public: // Full constructor for dynamic sizing. 
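
In the hunks that follow, MatPtr's string type tag becomes the Type enum, so element types can be switched on without string compares. The observable behavior, sketched with names from those hunks (the tensor name is illustrative):

    MatPtrT<SfpStream> tensor("qkv_ein_w", /*rows=*/16, /*cols=*/32);
    Type t = tensor.GetType();            // Type::kSFP; previously the string "sfp"
    size_t bytes = tensor.ElementSize();  // sizeof(SfpStream)
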
- MatPtr(const std::string& name, const std::string& type, size_t element_size, - size_t rows, size_t cols) + MatPtr(const std::string& name, Type type, size_t element_size, size_t rows, + size_t cols) : name_(name), type_(type), element_size_(element_size), @@ -129,7 +129,7 @@ class MatPtr { MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1, const hwy::uint128_t& key2, const hwy::uint128_t& key3) : name_(StringFromKey(key0)), - type_(StringFromKey(key1)), + type_(static_cast(key1.lo)), element_size_(key2.hi), num_elements_(key2.lo), rows_(key3.lo), @@ -138,7 +138,7 @@ class MatPtr { // Adds the contents entry to the table of contents. void AddToToc(std::vector& toc) const { toc.push_back(MakeKey(name_.c_str())); - toc.push_back(MakeKey(type_.c_str())); + toc.push_back({static_cast(type_), 0}); toc.push_back({num_elements_, element_size_}); toc.push_back({rows_, cols_}); } @@ -167,7 +167,7 @@ class MatPtr { void SetName(const std::string& name) { name_ = name; } // Returns the type of the blob. - const std::string& Type() const { return type_; } + Type GetType() const { return type_; } // Returns the size of each element in bytes. size_t ElementSize() const { return element_size_; } @@ -219,8 +219,8 @@ class MatPtr { protected: // Arbitrary name for the array of preferably <= 16 characters. std::string name_; - // Should be the result of TypeName for CallUpcasted() to work. - std::string type_; + // Should be the result of TypeEnum for CallUpcasted() to work. + Type type_; // sizeof(T) size_t element_size_ = 0; // Number of elements in the array. @@ -247,7 +247,7 @@ class MatPtrT : public MatPtr { // Full constructor for dynamic sizing. MatPtrT(const std::string& name, size_t rows, size_t cols) - : MatPtr(name, TypeName(), sizeof(MatT), rows, cols) {} + : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} // Copying allowed as the metadata is small. MatPtrT(const MatPtr& other) : MatPtr(other) {} @@ -330,17 +330,20 @@ class MatPtrT : public MatPtr { template decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { - if (type_ == TypeName()) { + if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); - } else if (type_ == TypeName()) { + } else if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); - } else if (type_ == TypeName()) { + } else if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); + } else if (type_ == TypeEnum()) { + return func(dynamic_cast*>(this), + std::forward(args)...); } else { - HWY_ABORT("Type %s unknown.", type_.c_str()); + HWY_ABORT("Type %d unknown.", type_); } } @@ -563,9 +566,10 @@ class CacheLoader { } // Returns whether all tensors are successfully loaded from cache. - bool ReadAll(hwy::ThreadPool& pool, std::vector& model_memory) { + BlobError ReadAll(hwy::ThreadPool& pool, + std::vector& model_memory) { // reader_ invalid or any Enqueue failed - if (err_ != 0) return false; + if (err_ != 0) return err_; // Setup the model_memory. 
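
CacheLoader::ReadAll now returns the BlobError itself rather than collapsing it to bool, so callers can report what failed. Caller-side sketch (zero means success, per the checks above; `loader` and `pool` are assumed context):

    std::vector<MatStorage> model_memory;
    const BlobError err = loader.ReadAll(pool, model_memory);
    if (err != 0) {
      // __LINE__-valued sentinels flag TOC mismatches; any other value is the
      // underlying reader's error code.
      fprintf(stderr, "Weight load failed (error %d)\n", err);
    }
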
for (int b = 0; b < model_toc_.size(); ++b) { const std::string& file_key = file_keys_[b]; @@ -574,12 +578,12 @@ class CacheLoader { const MatPtr* toc_blob = file_toc_.Get(file_key); if (toc_blob == nullptr) { fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str()); - return false; + return __LINE__; } if (toc_blob->Rows() != blob->Rows() || toc_blob->Cols() != blob->Cols()) { fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); - return false; + return __LINE__; } MatStorage toc_blob_array(*toc_blob); model_memory.push_back(std::move(toc_blob_array)); @@ -603,17 +607,10 @@ class CacheLoader { "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", blob.Name().c_str(), err_, blob.Rows(), blob.Cols(), blob.ElementSize()); - return false; + return err_; } } - - err_ = reader_.ReadAll(pool); - if (err_ != 0) { - fprintf(stderr, "Failed to read all tensors (error %d)\n", err_); - return false; - } - - return true; + return reader_.ReadAll(pool); } private: diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index 51897af..1a4fc52 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -24,6 +24,7 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "gemma/configs.h" #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE #define GEMMA_COMPRESS_WEIGHTS_ONCE @@ -150,29 +151,22 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template +template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, Model model_type, - Type weight_type, hwy::ThreadPool& pool) { + hwy::ThreadPool& pool) { if (!weights_path.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", weights_path.path.c_str()); } printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), compressed_weights_path.path.c_str()); - - using CConfig = typename Configs::c; - using UCConfig = typename Configs::uc; - // Allocate compressed weights. - using CWeights = CompressedWeights; - ByteStorageT c_weights_u8 = AllocateCompressedWeights()(pool); - CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); - - // Allocate uncompressed weights. - using UCWeights = CompressedWeights; - ByteStorageT uc_weights_u8 = AllocateCompressedWeights()(pool); - UCWeights* uc_weights = reinterpret_cast(uc_weights_u8.get()); - + ModelConfig config = ConfigFromModel(model_type); + std::vector model_storage; + ModelWeightsPtrs c_weights(config, pool); + c_weights.Allocate(model_storage, pool); + ModelWeightsPtrs uc_weights(config, pool); + uc_weights.Allocate(model_storage, pool); // Get uncompressed weights, compress, and store. 
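
Both pointer sets above allocate their tensors into the same model_storage vector, which owns the memory for the duration of the conversion. Assembled end state of this hunk (a sketch; `Weight` is the sole remaining template parameter):

    ModelConfig config = ConfigFromModel(model_type);
    std::vector<MatStorage> model_storage;
    ModelWeightsPtrs<Weight> c_weights(config, pool);  // compressed output
    c_weights.Allocate(model_storage, pool);
    ModelWeightsPtrs<float> uc_weights(config, pool);  // uncompressed input
    uc_weights.Allocate(model_storage, pool);
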
FILE* fptr = fopen(weights_path.path.c_str(), "rb"); if (fptr == nullptr) { @@ -181,22 +175,22 @@ void CompressWeights(const Path& weights_path, } bool ok = true; uint64_t total_size = 0; - CompressedWeights::ForEachTensor( - {uc_weights}, ForEachType::kLoadNoToc, + ModelWeightsPtrs::ForEachTensor( + {&uc_weights}, ForEachType::kLoadNoToc, [&](const char* name, hwy::Span tensors) { fprintf(stderr, "Loading Parameters (size %zu): %s\n", tensors[0]->SizeBytes(), name); ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); total_size += tensors[0]->SizeBytes(); }); - const bool scale_for_compression = UCConfig::kNumTensorScales > 0; + const bool scale_for_compression = config.num_tensor_scales > 0; std::vector scales; if (scale_for_compression) { - uc_weights->GetOrApplyScales(scales); + uc_weights.GetOrApplyScales(scales); } Compressor compressor(pool); - CompressedWeights::ForEachTensor( - {reinterpret_cast*>(uc_weights), c_weights}, + ModelWeightsPtrs::ForEachTensor( + {reinterpret_cast*>(&uc_weights), &c_weights}, ForEachType::kLoadNoToc, [&compressor](const char* name, hwy::Span tensors) { tensors[1]->CallUpcasted( @@ -221,9 +215,26 @@ void Run(Args& args) { HWY_ABORT("PaliGemma is not supported in compress_weights."); } const Type weight_type = args.WeightType(); - GEMMA_EXPORT_AND_DISPATCH( - model_type, weight_type, CompressWeights, - (args.weights, args.compressed_weights, model_type, weight_type, pool)); + switch (weight_type) { + case Type::kF32: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kBF16: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kSFP: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kNUQ: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + default: + HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); + } } } // namespace gcpp diff --git a/compression/shared.h b/compression/shared.h index c216d24..74b7454 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -32,11 +32,6 @@ namespace gcpp { using BF16 = hwy::bfloat16_t; -template -constexpr bool IsF32() { - return hwy::IsSame, float>(); -} - // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 // inputs that combines the advantages of e4m3 and e5m2 into a single format. // It supports seeking at a granularity of 1 and decoding to bf16/f32. @@ -179,29 +174,67 @@ struct NuqStream { }; #pragma pack(pop) +template +constexpr bool IsF32() { + return hwy::IsSame, float>(); +} + +template +constexpr bool IsBF16() { + return hwy::IsSame, BF16>(); +} + +template +constexpr bool IsSfpStream() { + return hwy::IsSame, SfpStream>(); +} + +template +constexpr bool IsNuqStream() { + return hwy::IsSame, NuqStream>(); +} + +// Instruction-tuned models require extra 'turn structure' tokens in prompts. +enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA }; + +// Tensor types for loading weights. Note that not all types are supported as +// weights for a model, but can be used for other purposes, such as types for +// ModelWeightsPtrs. When adding a new type that is supported, also +// update gemma.cc, weights.*, and add instantiations/new_one.cc. 
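
The enum defined in the next hunk and its string table are index-aligned, and TypeEnum<T>() / TypeName<T>() are the two lookups. Usage sketch:

    static_assert(static_cast<size_t>(Type::kU128) + 1 ==
                  sizeof(kTypeStrings) / sizeof(kTypeStrings[0]));
    const Type t = TypeEnum<BF16>();      // Type::kBF16
    const char* name = TypeName<BF16>();  // "bf16"
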
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 }; +constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", + "nuq", "f64", "c64", "u128"}; + +// Returns a Type enum for the type of the template parameter. template -const char* TypeName() { +Type TypeEnum() { using Packed = hwy::RemoveCvRef; if constexpr (hwy::IsSame()) { - return "f32"; + return Type::kF32; } else if constexpr (hwy::IsSame()) { - return "b16"; + return Type::kBF16; } else if constexpr (hwy::IsSame()) { - return "sfp"; + return Type::kSFP; } else if constexpr (hwy::IsSame()) { - return "nuq"; + return Type::kNUQ; } else if constexpr (hwy::IsSame()) { - return "f64"; + return Type::kF64; } else if constexpr (hwy::IsSame>()) { - return "c64"; + return Type::kC64; } else if constexpr (hwy::IsSame()) { - return "u128"; + return Type::kU128; } else { HWY_DASSERT(false); - return "unknown"; + return Type::kUnknown; } } +// Returns a string name for the type of the template parameter. +template +const char* TypeName() { + return kTypeStrings[static_cast(TypeEnum())]; +} + template constexpr bool IsCompressed() { return hwy::IsSameEither, SfpStream, NuqStream>(); diff --git a/evals/benchmark.cc b/evals/benchmark.cc index b59079a..1ea4f65 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -128,8 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache = KVCache::Create( - env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size); + KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(), + env.MutableConfig().prefill_tbatch_size); float entropy = ComputeCrossEntropy( *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); total_entropy += entropy; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 63553aa..abae040 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -69,8 +69,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, model_ = AllocateGemma(mutable_loader, pools_); // Only allocate one for starters because GenerateBatch might not be called. kv_caches_.resize(1); - kv_caches_[0] = - KVCache::Create(model_->Info().model, inference.prefill_tbatch_size); + kv_caches_[0] = KVCache::Create(model_->GetModelConfig(), + inference.prefill_tbatch_size); } InitGenerator(inference, gen_); runtime_config_ = { @@ -163,7 +163,7 @@ std::vector GemmaEnv::BatchQueryModel( } for (size_t i = 1; i < num_queries; ++i) { if (kv_caches_[i].seq_len == 0) { - kv_caches_[i] = KVCache::Create(model_->Info().model, + kv_caches_[i] = KVCache::Create(model_->GetModelConfig(), runtime_config_.prefill_tbatch_size); } } diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 870f84c..13ff3d3 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -103,8 +103,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens, const StreamFunc stream_token = [](int /*token*/, float) { return true; }; // TWeight is unused, but we have to pass it to Config*. 
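
The evals changes that follow show the payoff of storing a ModelConfig: call sites read sizes straight from the config instead of re-dispatching on the Model enum. Assembled from those hunks (sketch):

    const ModelConfig& config = model.GetModelConfig();
    KVCache kv_cache = KVCache::Create(config, inference.prefill_tbatch_size);
    const int vocab_size = config.vocab_size;  // previously fetched via a CallForModel functor
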
- const int vocab_size = - CallForModel(gemma.Info().model); + const int vocab_size = gemma.GetModelConfig().vocab_size; float cross_entropy = std::log(vocab_size); // first token size_t pos = 1; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 39d4f9c..2ed9b64 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -24,7 +24,6 @@ #include // Placeholder for internal header, do not modify. -#include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/tokenizer.h" #include "util/app.h" // LoaderArgs @@ -58,7 +57,8 @@ int main(int argc, char** argv) { gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::KVCache kv_cache = - gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); + gcpp::KVCache::Create(model.GetModelConfig(), + inference.prefill_tbatch_size); size_t generated = 0; // Initialize random number generator diff --git a/gemma/activations.h b/gemma/activations.h index b10b562..3983924 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,6 +21,7 @@ #include #include "compression/shared.h" // BF16 +#include "gemma/configs.h" #include "ops/matmul.h" // MatMulEnv #include "util/allocator.h" // RowVectorBatch #include "util/threading.h" @@ -30,6 +31,12 @@ namespace gcpp { struct Activations { + explicit Activations(const ModelConfig& config) + : weights_config(config), + layer_config(config.layer_configs[0]), + seq_len(config.seq_len), + cache_pos_size(config.CachePosSize()) {} + RowVectorBatch x; // input RowVectorBatch q; // query, also KV if MHA. RowVectorBatch logits; @@ -58,23 +65,24 @@ struct Activations { MatMulEnv env; + PostQKType post_qk = PostQKType::Rope; + // And the config. + const ModelConfig& weights_config; + const LayerConfig& layer_config; + size_t seq_len; + size_t cache_pos_size = 0; + // Multi-Head Attention? - template - static constexpr bool IsMHA() { - return TConfig::kHeads == TConfig::kKVHeads; - } + bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; } // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - template - static constexpr size_t QStride() { - return TConfig::kQKVDim * (IsMHA() ? 3 : 1); - } + size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); } - template - static RowVectorBatch CreateInvTimescale() { - constexpr size_t kQKVDim = TConfig::kQKVDim; - const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim; + static RowVectorBatch CreateInvTimescale(size_t qkv_dim, + PostQKType post_qk) { + const size_t rope_dim = + post_qk == PostQKType::HalfRope ? 
qkv_dim / 2 : qkv_dim; RowVectorBatch inv_timescale(1, rope_dim / 2); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const float freq_exponents = @@ -86,40 +94,38 @@ struct Activations { return inv_timescale; } - template void Allocate(size_t batch_size, PerClusterPools& pools) { - constexpr size_t kModelDim = TConfig::kModelDim; - constexpr size_t kQKVDim = TConfig::kQKVDim; - constexpr size_t kHeads = TConfig::kHeads; - constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - constexpr size_t kVocabSize = TConfig::kVocabSize; - constexpr size_t kSeqLen = TConfig::kSeqLen; - constexpr size_t kGriffinLayers = TConfig::kGriffinLayers; + post_qk = layer_config.post_qk; + const size_t model_dim = weights_config.model_dim; + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + const size_t vocab_size = weights_config.vocab_size; - x = RowVectorBatch(batch_size, kModelDim); - q = RowVectorBatch(batch_size, kHeads * QStride()); - if constexpr (kVocabSize > 0) { - logits = RowVectorBatch(batch_size, kVocabSize); + x = RowVectorBatch(batch_size, model_dim); + q = RowVectorBatch(batch_size, layer_config.heads * QStride()); + if (vocab_size > 0) { + logits = RowVectorBatch(batch_size, vocab_size); } - pre_att_rms_out = RowVectorBatch(batch_size, kModelDim); - att = RowVectorBatch(batch_size, kHeads * kSeqLen); - att_out = RowVectorBatch(batch_size, kHeads * kQKVDim); - att_sums = RowVectorBatch(batch_size, kModelDim); + pre_att_rms_out = RowVectorBatch(batch_size, model_dim); + att = RowVectorBatch(batch_size, + layer_config.heads * weights_config.seq_len); + att_out = RowVectorBatch(batch_size, + layer_config.heads * layer_config.qkv_dim); + att_sums = RowVectorBatch(batch_size, model_dim); - bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim); - C1 = RowVectorBatch(batch_size, kFFHiddenDim); - C2 = RowVectorBatch(batch_size, kFFHiddenDim); - ffw_out = RowVectorBatch(batch_size, kModelDim); + bf_pre_ffw_rms_out = RowVectorBatch(batch_size, model_dim); + C1 = RowVectorBatch(batch_size, ff_hidden_dim); + C2 = RowVectorBatch(batch_size, ff_hidden_dim); + ffw_out = RowVectorBatch(batch_size, model_dim); - if constexpr (kGriffinLayers > 0) { - griffin_x = RowVectorBatch(batch_size, kModelDim); - griffin_y = RowVectorBatch(batch_size, kModelDim); - griffin_gate_x = RowVectorBatch(batch_size, kModelDim); - griffin_multiplier = RowVectorBatch(batch_size, kModelDim); + if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { + griffin_x = RowVectorBatch(batch_size, model_dim); + griffin_y = RowVectorBatch(batch_size, model_dim); + griffin_gate_x = RowVectorBatch(batch_size, model_dim); + griffin_multiplier = RowVectorBatch(batch_size, model_dim); } - inv_timescale = CreateInvTimescale(); + inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk); env = MatMulEnv(pools); } diff --git a/gemma/common.cc b/gemma/common.cc index e68347b..447deb6 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -15,6 +15,7 @@ #include "gemma/common.h" +#include // sqrtf #include #include @@ -23,6 +24,7 @@ #include #include +#include "compression/shared.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -101,8 +103,6 @@ const char* ModelString(Model model, ModelTraining training) { static_cast(training)); } -constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"}; - const char* StringFromType(Type type) { return kTypeStrings[static_cast(type)]; } @@ -141,4 +141,19 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { prompt = start + 
prompt + "<end_of_turn>\n<start_of_turn>model\n";
   }
 }
+
+float EmbeddingScaling(size_t model_dim) {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<BF16>(
+      sqrtf(static_cast<float>(model_dim))));
+}
+
+float ChooseQueryScale(const ModelConfig& config) {
+  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
+    return 1.0f / sqrtf(static_cast<float>(config.model_dim /
+                                           config.layer_configs[0].heads));
+  // QueryScaleType::SqrtKeySize
+  return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
+}
+
 }  // namespace gcpp
diff --git a/gemma/common.h b/gemma/common.h
index 18ac5d1..e933e8d 100644
--- a/gemma/common.h
+++ b/gemma/common.h
@@ -16,37 +16,15 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
 
-#include <math.h>  // sqrtf
 #include <stddef.h>
 #include <string>
 
-#include "compression/compress.h"
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/base.h"  // ConvertScalarTo
 
 namespace gcpp {
 
-// Model variants: see configs.h for details. When adding a new one, also
-// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
-enum class Model {
-  GEMMA_2B,
-  GEMMA_7B,
-  GEMMA2_9B,
-  GEMMA2_27B,
-  GRIFFIN_2B,
-  GEMMA_TINY,
-  GEMMA2_2B,
-  PALIGEMMA_224,
-};
-
-// Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
-
-// Tensor types for loading weights. When adding a new one, also
-// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
-enum class Type { kF32, kBF16, kSFP };
-
 // TODO(janwas): merge with functions below.
 struct ModelInfo {
   Model model;
@@ -66,198 +44,12 @@ const char* StringFromType(Type type);
 
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
 
-// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
-// Config* is selected via `model`. Typically called by CallForModelAndWeight,
-// but can also be called directly when FuncT does not actually use TWeight.
-//
-// Note that a T prefix indicates a concrete type template argument, whereas a
-// T suffix indicates the argument is itself a template.
-//
-// `FuncT` must be a functor because function templates cannot be passed as a
-// template template argument, and we prefer to avoid the overhead of
-// std::function.
-template <typename TWeight, template <typename> class FuncT,
-          typename... TArgs>
-decltype(auto) CallForModel(Model model, TArgs&&... args) {
-  switch (model) {
-    case Model::GEMMA_TINY:
-      return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GEMMA_2B:
-      return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GEMMA_7B:
-      return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GEMMA2_9B:
-      return FuncT<ConfigGemma2_9B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GEMMA2_27B:
-      return FuncT<ConfigGemma2_27B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GRIFFIN_2B:
-      return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::GEMMA2_2B:
-      return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
-    case Model::PALIGEMMA_224:
-      return FuncT<ConfigPaliGemma_224<TWeight>>()(
-          std::forward<TArgs>(args)...);
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
-}
-
-// Returns the return value of FuncT<TConfig>().operator()(args),
-// where `TConfig` is selected based on `model` and `weight`.
-
-// This makes it easy to extend `Model` or `Type` without updating callers.
-//
-// Usage example: LoadWeights is type-erased so that it can be called from other
-// .cc files. It uses this function to call the appropriate instantiation of a
-// template functor LoadCompressedWeightsT.
-template