diff --git a/BUILD.bazel b/BUILD.bazel
index 1e0cc73..c480f23 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -104,8 +104,6 @@ cc_test(
     tags = ["hwy_ops_test"],
     deps = [
         ":allocator",
-        ":common",
-        ":gemma_lib",
         ":ops",
         ":test_util",
         ":threading",
@@ -183,7 +181,10 @@ cc_test(
 
 cc_library(
     name = "common",
-    srcs = ["gemma/common.cc"],
+    srcs = [
+        "gemma/common.cc",
+        "gemma/configs.cc",
+    ],
     hdrs = [
         "gemma/common.h",
         "gemma/configs.h",
@@ -195,12 +196,20 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "configs_test",
+    srcs = ["gemma/configs_test.cc"],
+    deps = [
+        ":common",
+        "@googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "weights",
     srcs = ["gemma/weights.cc"],
     hdrs = ["gemma/weights.h"],
     deps = [
-        ":allocator",
         ":common",
         "//compression:compress",
         "//compression:io",
@@ -219,7 +228,6 @@ cc_library(
         ":common",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:nanobenchmark",  # timer
         "@highway//:profiler",
         "@com_google_sentencepiece//:sentencepiece_processor",
     ],
 )
@@ -239,30 +247,10 @@ cc_library(
     name = "gemma_lib",
     srcs = [
         "gemma/gemma.cc",
-        "gemma/instantiations/27b_bf16.cc",
-        "gemma/instantiations/27b_f32.cc",
-        "gemma/instantiations/27b_sfp.cc",
-        "gemma/instantiations/2b_bf16.cc",
-        "gemma/instantiations/2b_f32.cc",
-        "gemma/instantiations/2b_sfp.cc",
-        "gemma/instantiations/7b_bf16.cc",
-        "gemma/instantiations/7b_f32.cc",
-        "gemma/instantiations/7b_sfp.cc",
-        "gemma/instantiations/9b_bf16.cc",
-        "gemma/instantiations/9b_f32.cc",
-        "gemma/instantiations/9b_sfp.cc",
-        "gemma/instantiations/tiny_bf16.cc",
-        "gemma/instantiations/tiny_f32.cc",
-        "gemma/instantiations/tiny_sfp.cc",
-        "gemma/instantiations/gr2b_bf16.cc",
-        "gemma/instantiations/gr2b_f32.cc",
-        "gemma/instantiations/gr2b_sfp.cc",
-        "gemma/instantiations/gemma2_2b_bf16.cc",
-        "gemma/instantiations/gemma2_2b_f32.cc",
-        "gemma/instantiations/gemma2_2b_sfp.cc",
-        "gemma/instantiations/paligemma_224_bf16.cc",
-        "gemma/instantiations/paligemma_224_f32.cc",
-        "gemma/instantiations/paligemma_224_sfp.cc",
+        "gemma/instantiations/bf16.cc",
+        "gemma/instantiations/f32.cc",
+        "gemma/instantiations/nuq.cc",
+        "gemma/instantiations/sfp.cc",
     ],
     hdrs = [
         "gemma/activations.h",
@@ -327,8 +315,6 @@ cc_library(
         ":threading",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:thread_pool",
-        "@highway//:topology",
     ],
 )
@@ -367,7 +353,6 @@ cc_test(
         ":benchmark_helper",
         ":common",
         ":gemma_lib",
-        ":tokenizer",
        "@googletest//:gtest_main",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
@@ -396,7 +381,6 @@ cc_binary(
     name = "single_benchmark",
     srcs = ["evals/benchmark.cc"],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":common",
         "//compression:io",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
-        "@highway//:thread_pool",
         "@nlohmann_json//:json",
     ],
 )
@@ -429,13 +412,11 @@ cc_binary(
         "evals/debug_prompt.cc",
     ],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":gemma_lib",
         "//compression:io",
         "@highway//:hwy",
-        "@highway//:thread_pool",
         "@nlohmann_json//:json",
     ],
 )
@@ -444,7 +425,6 @@ cc_binary(
     name = "gemma_mmlu",
     srcs = ["evals/run_mmlu.cc"],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":gemma_lib",
@@ -488,7 +468,6 @@ cc_library(
     deps = [
         ":allocator",
         ":common",
-        ":gemma_lib",
         ":ops",
         ":prompt",
         ":weights",
@@ -508,7 +487,6 @@ cc_library(
         "backprop/forward_scalar.h",
     ],
     deps = [
-        ":allocator",
         ":common",
         ":prompt",
         ":weights",
@@ -525,7 +503,6 @@ cc_test(
         "backprop/test_util.h",
     ],
     deps = [
-        ":allocator",
         ":backprop_scalar",
         ":common",
         ":prompt",
@@ -599,6 +576,7 @@ cc_test(
         ":threading",
":weights", "@googletest//:gtest_main", + "//compression:sfp", "@highway//:thread_pool", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 51ab2e4..bade5de 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,34 +68,15 @@ set(SOURCES gemma/activations.h gemma/common.cc gemma/common.h + gemma/configs.cc gemma/configs.h gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h - gemma/instantiations/27b_bf16.cc - gemma/instantiations/27b_f32.cc - gemma/instantiations/27b_sfp.cc - gemma/instantiations/2b_bf16.cc - gemma/instantiations/2b_f32.cc - gemma/instantiations/2b_sfp.cc - gemma/instantiations/7b_bf16.cc - gemma/instantiations/7b_f32.cc - gemma/instantiations/7b_sfp.cc - gemma/instantiations/9b_bf16.cc - gemma/instantiations/9b_f32.cc - gemma/instantiations/9b_sfp.cc - gemma/instantiations/gr2b_bf16.cc - gemma/instantiations/gr2b_f32.cc - gemma/instantiations/gr2b_sfp.cc - gemma/instantiations/tiny_bf16.cc - gemma/instantiations/tiny_f32.cc - gemma/instantiations/tiny_sfp.cc - gemma/instantiations/gemma2_2b_bf16.cc - gemma/instantiations/gemma2_2b_f32.cc - gemma/instantiations/gemma2_2b_sfp.cc - gemma/instantiations/paligemma_224_bf16.cc - gemma/instantiations/paligemma_224_f32.cc - gemma/instantiations/paligemma_224_sfp.cc + gemma/instantiations/bf16.cc + gemma/instantiations/f32.cc + gemma/instantiations/nuq.cc + gemma/instantiations/sfp.cc gemma/kv_cache.cc gemma/kv_cache.h gemma/tokenizer.cc diff --git a/backprop/activations.h b/backprop/activations.h index 4f2e821..c616759 100644 --- a/backprop/activations.h +++ b/backprop/activations.h @@ -18,32 +18,27 @@ #include -#include +#include #include "compression/compress.h" // MatStorageT -#include "util/allocator.h" // ByteStorageT +#include "gemma/configs.h" // ModelConfig namespace gcpp { -template +template struct ForwardLayer { - ForwardLayer() - : input("input", kSeqLen, kModelDim), - pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim), - qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim), - att("att", kSeqLen * kHeads, kSeqLen), - att_out("att_out", kSeqLen * kHeads, kQKVDim), - att_post1("att_post1", kSeqLen, kModelDim), - attention_out("attention_out", kSeqLen, kModelDim), - bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim), - ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2), - ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {} - - static constexpr size_t kSeqLen = TConfig::kSeqLen; - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + ForwardLayer(const LayerConfig& config, size_t seq_len) + : input("input", seq_len, config.model_dim), + pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim), + qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim), + att("att", seq_len * config.heads, seq_len), + att_out("att_out", seq_len * config.heads, config.qkv_dim), + att_post1("att_post1", seq_len, config.model_dim), + attention_out("attention_out", seq_len, config.model_dim), + bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim), + ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2), + ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim), + layer_config(config) {} MatStorageT input; MatStorageT pre_att_rms_out; @@ -55,56 +50,30 @@ struct ForwardLayer { MatStorageT bf_pre_ffw_rms_out; MatStorageT ffw_hidden; MatStorageT ffw_hidden_gated; + const LayerConfig& layer_config; }; -template +template 
 struct ForwardPass {
-  ForwardPass()
-      : final_layer_output("final_layer_output", kSeqLen, kModelDim),
-        final_norm_output("final_norm_output", kSeqLen, kModelDim),
-        logits("logits", kSeqLen, kVocabSize),
-        probs("probs", kSeqLen, kVocabSize) {
-  }  // prevents placement-new calling memset
+  ForwardPass(const ModelConfig& config)
+      : final_layer_output("final_layer_output", config.seq_len,
+                           config.model_dim),
+        final_norm_output("final_norm_output", config.seq_len,
+                          config.model_dim),
+        logits("logits", config.seq_len, config.vocab_size),
+        probs("probs", config.seq_len, config.vocab_size),
+        weights_config(config) {
+    for (const auto& layer_config : config.layer_configs) {
+      layers.emplace_back(layer_config, config.seq_len);
+    }
+  }
 
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kLayers = TConfig::kLayers;
-
-  std::array<ForwardLayer<T, TConfig>, kLayers> layers;
+  std::vector<ForwardLayer<T>> layers;
   MatStorageT<T> final_layer_output;
   MatStorageT<T> final_norm_output;
   MatStorageT<T> logits;
   MatStorageT<T> probs;
-};
-
-template <typename T, typename TConfig>
-struct AllocateForwardPass {
-  ByteStorageT operator()() const {
-    ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<T, TConfig>>();
-    auto* c_weights =
-        reinterpret_cast<ForwardPass<T, TConfig>*>(c_weights_u8.get());
-    new (c_weights) ForwardPass<T, TConfig>();
-    return c_weights_u8;
-  }
-};
-
-// Owns activations and undoes the type erasure of AllocateAligned.
-template <typename T, typename TConfig>
-class ActivationsWrapper {
-  using WrappedT = ForwardPass<T, TConfig>;
-
- public:
-  ActivationsWrapper()
-      : data_(AllocateSizeof<WrappedT>()),
-        activations_(*(new(data_.get()) WrappedT())) {}
-
-  const WrappedT& get() const { return activations_; }
-  WrappedT& get() { return activations_; }
-
- private:
-  ByteStorageT data_;
-  WrappedT& activations_;
+  const ModelConfig& weights_config;
 };
 
 }  // namespace gcpp
diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index f765a5a..2a0f330 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -28,6 +28,7 @@
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
 #include "gemma/common.h"
+#include "gemma/weights.h"
 #include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -53,45 +54,41 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-template <size_t kCols, size_t kRows>
-void MatMulVJP(const float* HWY_RESTRICT weights,  // kRows * kCols,
-               const float* HWY_RESTRICT x,  // num_tokens * kCols
-               const float* HWY_RESTRICT v,  // num_tokens * kRows
-               size_t num_tokens,
-               float* HWY_RESTRICT grad_w,  // kRows * kCols,
-               float* HWY_RESTRICT grad_x,  // num_tokens * kCols
-               hwy::ThreadPool& pool) {
-  hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
+HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights,  // kRows * kCols,
+                          const float* HWY_RESTRICT x,  // num_tokens * kCols
+                          const float* HWY_RESTRICT v,  // num_tokens * kRows
+                          size_t cols, size_t rows, size_t num_tokens,
+                          float* HWY_RESTRICT grad_w,  // kRows * kCols,
+                          float* HWY_RESTRICT grad_x,  // num_tokens * kCols
+                          hwy::ThreadPool& pool) {
+  hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    const size_t voffs = pos * kRows;
-    const size_t xoffs = pos * kCols;
-    for (size_t j = 0; j < kRows; ++j) {
-      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
-      MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
-                       kCols);
+    const size_t voffs = pos * rows;
+    const size_t xoffs = pos * cols;
+    for (size_t j = 0; j < rows; ++j) {
+      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
+      MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
     }
   }
 }
 
-template <size_t kHeads, size_t kCols, size_t kRows>
-void MultiHeadMatMulVJP(
-    const float* HWY_RESTRICT weights,  // kHeads * kRows * kCols
-    const float* HWY_RESTRICT x,        // num_tokens * kHeads * kCols
+HWY_INLINE void MultiHeadMatMulVJP(
+    const float* HWY_RESTRICT weights,  // heads * kRows * kCols
+    const float* HWY_RESTRICT x,        // num_tokens * heads * kCols
     const float* HWY_RESTRICT v,        // num_tokens * kRows
-    size_t num_tokens,
-    float* HWY_RESTRICT grad_w,  // kHeads * kRows * kCols
-    float* HWY_RESTRICT grad_x,  // num_tokens * kHeads * kCols
+    size_t heads, size_t cols, size_t rows, size_t num_tokens,
+    float* HWY_RESTRICT grad_w,  // heads * kRows * kCols
+    float* HWY_RESTRICT grad_x,  // num_tokens * heads * kCols
     hwy::ThreadPool& pool) {
-  hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
+  hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    for (size_t j = 0; j < kRows; ++j) {
-      for (size_t h = 0; h < kHeads; ++h) {
-        MulByConstAndAdd(v[pos * kRows + j],
-                         &x[pos * kHeads * kCols + h * kCols],
-                         &grad_w[h * kRows * kCols + j * kCols], kCols);
-        MulByConstAndAdd(v[pos * kRows + j],
-                         &weights[h * kRows * kCols + j * kCols],
-                         &grad_x[pos * kHeads * kCols + h * kCols], kCols);
+    for (size_t j = 0; j < rows; ++j) {
+      for (size_t h = 0; h < heads; ++h) {
+        MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
+                         &grad_w[h * rows * cols + j * cols], cols);
+        MulByConstAndAdd(v[pos * rows + j],
+                         &weights[h * rows * cols + j * cols],
+                         &grad_x[pos * heads * cols + h * cols], cols);
       }
     }
   }
@@ -168,39 +165,39 @@ static HWY_NOINLINE void InputEmbeddingVJP(
   }
 }
 
-template <typename TConfig, typename LayerT>
-void LayerVJP(const LayerT& weights,
-              const ForwardLayer<float, TConfig>& forward,
+template <typename T>
+void LayerVJP(const LayerWeightsPtrs<T>& weights,
+              const ForwardLayer<T>& forward,
               const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
-              LayerT& grad, ForwardLayer<float, TConfig>& backward,
+              LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
               const RowVectorBatch<float>& inv_timescale,
               hwy::ThreadPool& pool) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  static const float kQueryScale =
-      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
-  HWY_ASSERT(num_tokens <= kSeqLen);
+  const LayerConfig& config = weights.layer_config;
+  const size_t model_dim = config.model_dim;
+  const size_t qkv_dim = config.qkv_dim;
+  const size_t heads = config.heads;
+  const size_t seq_len = forward.input.Rows();
+  const size_t ff_hidden_dim = config.ff_hidden_dim;
+  const float query_scale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
+  HWY_ASSERT(num_tokens <= seq_len);
 
-  MatMulVJP<kFFHiddenDim, kModelDim>(
-      weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
-      num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
-      pool);
+  MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
+            next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
+            grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    const size_t hidden_offset = pos * kFFHiddenDim * 2;
+    const size_t hidden_offset = pos * ff_hidden_dim * 2;
     const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
-    const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
+    const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
     const float* HWY_RESTRICT b_out_gated =
-        backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
+        backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
     float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
-    float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
+    float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
     DF df;
-    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
+    for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
       const auto y = Load(df, f_out + i);
       const auto x = Load(df, f_out_mul + i);
       const auto v = Load(df, b_out_gated + i);
@@ -209,101 +206,94 @@ void LayerVJP(const LayerT& weights,
     }
   }
 
-  MatMulVJP<kModelDim, kFFHiddenDim * 2>(
-      weights.gating_einsum_w.data(),
-      forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
-      num_tokens, grad.gating_einsum_w.data(),
-      backward.bf_pre_ffw_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_ffw_norm_scale.data(),
-             forward.attention_out.data(),
-             backward.bf_pre_ffw_rms_out.data(),
-             kModelDim, num_tokens,
-             grad.pre_ffw_norm_scale.data(),
-             backward.attention_out.data(), pool);
+  MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
+            backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
+            num_tokens, grad.gating_einsum_w.data(),
+            backward.bf_pre_ffw_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
+             backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
+             grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
+             pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(next_layer_grad + pos * kModelDim,
-            backward.attention_out.data() + pos * kModelDim, kModelDim);
+    AddFrom(next_layer_grad + pos * model_dim,
+            backward.attention_out.data() + pos * model_dim, model_dim);
   }
 
   backward.qkv.ZeroInit();
-  MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
-      weights.attn_vec_einsum_w.data(), forward.att_out.data(),
-      backward.attention_out.data(), num_tokens,
-      grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
+  MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
+                     backward.attention_out.data(), heads, qkv_dim, model_dim,
+                     num_tokens, grad.attn_vec_einsum_w.data(),
+                     backward.att_out.data(), pool);
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t aoffset = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
       const float* HWY_RESTRICT b_att_out =
-          backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
+          backward.att_out.data() + (pos * heads + head) * qkv_dim;
       float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
+        const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
         const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
         float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
-        b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
-        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
+        b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
+        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
       }
     }
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t aoffset = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
       float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
       SoftmaxVJP(f_head_att, b_head_att, pos + 1);
     }
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
-      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
+      const size_t aoffs = head * seq_len + pos * heads * seq_len;
       const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
       const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
       float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+        const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
         const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
         float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
-        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
-        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
+        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
+        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
       }
     }
   }
 
   for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
     float* HWY_RESTRICT b_kv =
-        backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
-    Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos);
+        backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
+    Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
   }
 
-  for (size_t head = 0; head < kHeads; ++head) {
+  for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
       float* HWY_RESTRICT b_q =
-          backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
-      MulByConst(kQueryScale, b_q, kQKVDim);
-      Rope(b_q, kQKVDim, inv_timescale.Const(), -pos);
+          backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
+      MulByConst(query_scale, b_q, qkv_dim);
+      Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
     }
   }
 
-  MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
-      weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
-      backward.qkv.data(), num_tokens,
-      grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_attention_norm_scale.data(),
-             forward.input.data(),
-             backward.pre_att_rms_out.data(),
-             kModelDim, num_tokens,
-             grad.pre_attention_norm_scale.data(),
-             backward.input.data(), pool);
+  MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
+            backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
+            grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
+             backward.pre_att_rms_out.data(), model_dim, num_tokens,
+             grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(backward.attention_out.data() + pos * kModelDim,
-            backward.input.data() + pos * kModelDim, kModelDim);
+    AddFrom(backward.attention_out.data() + pos * model_dim,
+            backward.input.data() + pos * model_dim, model_dim);
   }
 }
 
@@ -342,20 +332,22 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
   }
 }
 
-template <typename TConfig, typename WeightsT, typename LayerT>
-void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
-                                  const ForwardPass<float, TConfig>& forward,
-                                  WeightsT& grad,
-                                  ForwardPass<float, TConfig>& backward,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool) {
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kLayers = TConfig::kLayers;
-  const float kEmbScaling = EmbeddingScaling<TConfig>();
-  static_assert(!TConfig::kAbsolutePE);
-  static_assert(TConfig::kPostNorm == PostNormType::None);
-  static_assert(TConfig::kKVHeads == 1);
+template <typename T>
+void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
+                                     const ModelWeightsPtrs<T>& weights,
+                                     const ForwardPass<T>& forward,
+                                     ModelWeightsPtrs<T>& grad,
+                                     ForwardPass<T>& backward,
+                                     RowVectorBatch<float>& inv_timescale,
+                                     hwy::ThreadPool& pool) {
+  const ModelConfig& config = weights.weights_config;
+  const size_t kVocabSize = config.vocab_size;
+  const size_t model_dim = config.model_dim;
+  const size_t kLayers = config.layer_configs.size();
+  const float kEmbScaling = EmbeddingScaling(model_dim);
+  HWY_ASSERT(!config.absolute_pe);
+  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
+  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
 
   HWY_DASSERT(prompt.context_size > 0);
   HWY_DASSERT(prompt.context_size < prompt.tokens.size());
@@ -370,42 +362,38 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
                          kVocabSize);
   }
 
-  if constexpr (TConfig::kFinalCap > 0.0f) {
+  if (config.final_cap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
+      SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
                  backward.logits.data() + pos * kVocabSize, kVocabSize);
     }
   }
 
-  MatMulVJP<kModelDim, kVocabSize>(
-      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
-      backward.logits.data(), num_tokens,
-      grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
-      pool);
+  MatMulVJP(weights.embedder_input_embedding.data(),
+            forward.final_norm_output.data(), backward.logits.data(), model_dim,
+            kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
+            backward.final_norm_output.data(), pool);
 
-  RMSNormVJP(weights.final_norm_scale.data(),
-             forward.final_layer_output.data(),
-             backward.final_norm_output.data(),
-             kModelDim, num_tokens,
-             grad.final_norm_scale.data(),
-             backward.final_layer_output.data(), pool);
+  RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
+             backward.final_norm_output.data(), model_dim, num_tokens,
+             grad.final_norm_scale.data(), backward.final_layer_output.data(),
+             pool);
 
   for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
-    auto type = TConfig::kLayerConfig[layer];
+    auto layer_config = config.layer_configs[layer];
     // TODO(szabadka) Implement Griffin layer vjp.
-    HWY_ASSERT(type == LayerAttentionType::kGemma);
+    HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
     float* next_layer_grad = layer + 1 < kLayers
                                  ? backward.layers[layer + 1].input.data()
                                  : backward.final_layer_output.data();
-    LayerVJP(*weights.GetLayer(layer), forward.layers[layer],
-             next_layer_grad, num_tokens,
-             *grad.GetLayer(layer), backward.layers[layer],
-             inv_timescale, pool);
+    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
+             num_tokens, *grad.GetLayer(layer), backward.layers[layer],
+             inv_timescale, pool);
   }
 
   InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
                     kEmbScaling, backward.layers[0].input.data(),
-                    grad.embedder_input_embedding.data(), kModelDim);
+                    grad.embedder_input_embedding.data(), model_dim);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
diff --git a/backprop/backward.cc b/backprop/backward.cc
index c186952..868b391 100644
--- a/backprop/backward.cc
+++ b/backprop/backward.cc
@@ -38,44 +38,15 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-template <typename TConfig>
-void CrossEntropyLossBackwardPass(const Prompt& prompt,
-                                  const ByteStorageT& weights_u8,
-                                  const ByteStorageT& forward_u8,
-                                  ByteStorageT& grad_u8,
-                                  ByteStorageT& backward_u8,
-                                  RowVectorBatch<float>& inv_timescale,
-                                  hwy::ThreadPool& pool) {
-  using TWeights = CompressedWeights<TConfig>;
-  const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
-  auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
-  using TAct = ForwardPass<float, TConfig>;
-  const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
-  auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
-  CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
-                               CompressedLayer<TConfig>>(
-      prompt, weights, forward, grad, backward, inv_timescale, pool);
-}
-
-void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
-                                   const ByteStorageT& weights,
-                                   const ByteStorageT& forward,
-                                   ByteStorageT& grad, ByteStorageT& backward,
+void CrossEntropyLossBackwardPassT(const Prompt& prompt,
+                                   const ModelWeightsPtrs<float>& weights,
+                                   const ForwardPass<float>& forward,
+                                   ModelWeightsPtrs<float>& grad,
+                                   ForwardPass<float>& backward,
                                    RowVectorBatch<float>& inv_timescale,
                                    hwy::ThreadPool& pool) {
-  // TODO(janwas): use CallFunctorForModel
-  switch (model) {
-    case Model::GEMMA_2B:
-      CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
-          prompt, weights, forward, grad, backward, inv_timescale, pool);
-      break;
-    case Model::GEMMA_TINY:
-      CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
-          prompt, weights, forward, grad, backward, inv_timescale, pool);
-      break;
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
+  CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
+                                  inv_timescale, pool);
 }
 
 }  // namespace HWY_NAMESPACE
@@ -87,14 +58,15 @@ namespace gcpp {
 
 HWY_EXPORT(CrossEntropyLossBackwardPassT);
 
-void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
-                                  const ByteStorageT& weights,
-                                  const ByteStorageT& forward,
-                                  ByteStorageT& grad, ByteStorageT& backward,
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const ModelWeightsPtrs<float>& weights,
+                                  const ForwardPass<float>& forward,
+                                  ModelWeightsPtrs<float>& grad,
+                                  ForwardPass<float>& backward,
                                   RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
   return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
-      model, prompt, weights, forward, grad, backward, inv_timescale, pool);
+      prompt, weights, forward, grad, backward, inv_timescale, pool);
 }
 
 }  // namespace gcpp
diff --git a/backprop/backward.h b/backprop/backward.h
index 0ac218a..d8e50c7 100644
--- a/backprop/backward.h
+++ b/backprop/backward.h
@@ -16,17 +16,19 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
 
+#include "backprop/activations.h"
"backprop/prompt.h" -#include "gemma/activations.h" -#include "gemma/common.h" +#include "gemma/weights.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt, - const ByteStorageT& weights, - const ByteStorageT& forward, - ByteStorageT& grad, ByteStorageT& backward, +void CrossEntropyLossBackwardPass(const Prompt& prompt, + const ModelWeightsPtrs& weights, + const ForwardPass& forward, + ModelWeightsPtrs& grad, + ForwardPass& backward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool); diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index a804cd3..b0a37b3 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -125,65 +125,64 @@ void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) { } } - -template +template void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv, - size_t num_tokens, size_t kHeads, size_t kQKVDim, - size_t kSeqLen) { + size_t num_tokens, size_t kHeads, size_t qkv_dim, + size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t offset = pos * (kHeads + 2) * kQKVDim; - memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0])); + const size_t offset = pos * (kHeads + 2) * qkv_dim; + memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0])); } for (size_t head = 0; head < kHeads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { - const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim; - const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen; + const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim; + const size_t aoffs = head * seq_len + pos * kHeads * seq_len; const T* q = qkv + qoffs; const T* dout = doutput + aoffs; T* dq = dqkv + qoffs; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim; const T* k = qkv + koffs; T* dk = dqkv + koffs; - MulByConstAndAddT(dout[pos2], k, dq, kQKVDim); - MulByConstAndAddT(dout[pos2], q, dk, kQKVDim); + MulByConstAndAddT(dout[pos2], k, dq, qkv_dim); + MulByConstAndAddT(dout[pos2], q, dk, qkv_dim); } } } } -template -void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, - size_t kHeads, size_t kSeqLen) { +template +void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads, + size_t seq_len) { for (size_t head = 0; head < kHeads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { - size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + size_t offset = pos * kHeads * seq_len + head * seq_len; SoftmaxVJPT(y + offset, dy + offset, pos + 1); - memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); } } } -template +template void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, - T* dqkv, T* dattention, size_t num_tokens, - size_t kHeads, size_t kQKVDim, size_t kSeqLen) { + T* dqkv, T* dattention, size_t num_tokens, size_t kHeads, + size_t qkv_dim, size_t seq_len) { auto v_offset = [&](size_t pos) { - return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim; + return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim; }; for (size_t pos = 0; pos < num_tokens; ++pos) { - memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0])); + memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0])); } for (size_t head = 0; head < kHeads; ++head) { for (size_t pos = 
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
-      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
+      const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
       const T* att = &attention[aoffset];
       const T* dout = &doutput[offset];
       T* datt = &dattention[aoffset];
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
-        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
+        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
+        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
       }
     }
   }
@@ -199,77 +198,76 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
   }
 }
 
-template <typename T, typename TConfig>
-void LayerVJP(const CompressedLayer<TConfig>& weights,
-              const ForwardLayer<T, TConfig>& forward, const T* dy,
-              CompressedLayer<TConfig>& grad,
-              ForwardLayer<T, TConfig>& backward, size_t num_tokens) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
+template <typename T>
+void LayerVJP(const LayerWeightsPtrs<T>& weights,
+              const ForwardLayer<T>& forward, const T* dy,
+              LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
+              size_t num_tokens) {
+  const LayerConfig& layer_config = weights.layer_config;
+  const size_t model_dim = layer_config.model_dim;
+  const size_t seq_len = forward.input.Rows();
+  const size_t qkv_dim = layer_config.qkv_dim;
+  const size_t kHeads = layer_config.heads;
+  const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
+  const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
 
-  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
-             dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
-             kModelDim, kFFHiddenDim, num_tokens);
+  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
+             grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
+             kFFHiddenDim, num_tokens);
   GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
                backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
 
   MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
              backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
-             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
+             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
              num_tokens);
 
   RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
               backward.bf_pre_ffw_rms_out.data(),
               grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
-              kModelDim, num_tokens);
+              model_dim, num_tokens);
 
-  AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
+  AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
 
   MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
                       backward.attention_out.data(),
-                      grad.attn_vec_einsum_w.data(),
-                      backward.att_out.data(),
-                      kHeads, kModelDim, kQKVDim, num_tokens);
+                      grad.attn_vec_einsum_w.data(), backward.att_out.data(),
+                      kHeads, model_dim, qkv_dim, num_tokens);
 
   MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
                     backward.att_out.data(), backward.qkv.data(),
-                    backward.att.data(), num_tokens, kHeads, kQKVDim,
-                    kSeqLen);
+                    backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
 
-  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
-                    num_tokens, kHeads, kSeqLen);
+  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
+                    seq_len);
 
   MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
-                     backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
+                     backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
-    MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
+    MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
   }
 
   for (int pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
     for (size_t h = 0; h <= kHeads; ++h) {
-      Rope(qkv + h * kQKVDim, kQKVDim, -pos);
+      Rope(qkv + h * qkv_dim, qkv_dim, -pos);
     }
   }
 
   MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
              backward.qkv.data(), grad.qkv_einsum_w.data(),
-             backward.pre_att_rms_out.data(),
-             (kHeads + 2) * kQKVDim, kModelDim, num_tokens);
+             backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
+             num_tokens);
   RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
               backward.pre_att_rms_out.data(),
-              grad.pre_attention_norm_scale.data(),
-              backward.input.data(), kModelDim, num_tokens);
+              grad.pre_attention_norm_scale.data(), backward.input.data(),
+              model_dim, num_tokens);
 
   AddFromT(backward.attention_out.data(), backward.input.data(),
-           num_tokens * kModelDim);
+           num_tokens * model_dim);
 }
 
 template <typename T>
@@ -296,56 +294,54 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
   }
 }
 
-template <typename T, typename TConfig>
+template <typename T>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
-                                  const CompressedWeights<TConfig>& weights,
-                                  const ForwardPass<T, TConfig>& forward,
-                                  CompressedWeights<TConfig>& grad,
-                                  ForwardPass<T, TConfig>& backward) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kLayers = TConfig::kLayers;
+                                  const ModelWeightsPtrs<T>& weights,
+                                  const ForwardPass<T>& forward,
+                                  ModelWeightsPtrs<T>& grad,
+                                  ForwardPass<T>& backward) {
+  const ModelConfig& config = weights.weights_config;
+  const size_t model_dim = config.model_dim;
+  const size_t vocab_size = config.vocab_size;
+  const size_t layers = config.layer_configs.size();
   const std::vector<int> tokens = prompt.tokens;
   const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
 
   CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
-                       kVocabSize);
+                       vocab_size);
 
-  SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
-              kVocabSize, num_tokens);
+  SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
+              num_tokens);
 
-  if constexpr (TConfig::kFinalCap > 0.0f) {
+  if (config.final_cap > 0.0f) {
     for (size_t i = 0; i < num_tokens; ++i) {
-      SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
-                  backward.logits.data() + i * kVocabSize, kVocabSize);
+      SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
+                  backward.logits.data() + i * vocab_size, vocab_size);
     }
   }
 
-  MatMulVJPT(weights.embedder_input_embedding.data(),
-             forward.final_norm_output.data(),
-             backward.logits.data(),
-             grad.embedder_input_embedding.data(),
-             backward.final_norm_output.data(),
-             kVocabSize, kModelDim, num_tokens);
+  MatMulVJPT(
+      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
+      backward.logits.data(), grad.embedder_input_embedding.data(),
+      backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
 
   RMSNormVJPT(weights.final_norm_scale.data(),
               forward.final_layer_output.data(),
-              backward.final_norm_output.data(),
-              grad.final_norm_scale.data(),
-              backward.final_layer_output.data(), kModelDim, num_tokens);
+              backward.final_norm_output.data(), grad.final_norm_scale.data(),
+              backward.final_layer_output.data(), model_dim, num_tokens);
 
-  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
-    T* next_layer_grad = layer + 1 < kLayers
-                             ? backward.layers[layer + 1].input.data()
-                             : backward.final_layer_output.data();
+  for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
+    T* next_layer_grad = layer + 1 < layers
+                             ? backward.layers[layer + 1].input.data()
+                             : backward.final_layer_output.data();
     LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
              *grad.GetLayer(layer), backward.layers[layer], num_tokens);
   }
 
-  const T kEmbScaling = EmbeddingScaling(kModelDim);
-  InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
-                     tokens, kEmbScaling, backward.layers[0].input.data(),
-                     grad.embedder_input_embedding.data(), kModelDim);
+  const T kEmbScaling = EmbeddingScaling(model_dim);
+  InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
+                     kEmbScaling, backward.layers[0].input.data(),
+                     grad.embedder_input_embedding.data(), model_dim);
 }
 
 }  // namespace gcpp
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index 262a121..b5e39db 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -19,7 +19,6 @@
 #include <stddef.h>
 #include <string.h>  // memcpy
 
-#include <array>
 #include <complex>
 #include <random>
 #include <vector>
@@ -384,44 +383,49 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   }
 }
 
-template <typename T>
-struct TestConfig : ConfigBaseGemmaV2 {
-  using Weight = T;
-  static constexpr int kSeqLen = 18;
-  static constexpr int kVocabSize = 12;
-  static constexpr int kModelDim = 32;
-  static constexpr int kHeads = 3;
-  static constexpr int kQKVDim = 12;
-  static constexpr int kFFHiddenDim = 48;
-  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
-      FixedLayerConfig<2>(LayerAttentionType::kGemma);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kNumTensorScales = 4 * kLayers;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-
-  static constexpr int kKVHeads = 1;
-  static constexpr int kGemmaLayers = kLayers;
-};
+static ModelConfig TestConfig() {
+  ModelConfig config;
+  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
+                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
+  config.model_dim = 32;
+  config.vocab_size = 12;
+  config.seq_len = 18;
+  LayerConfig layer_config = {
+      .model_dim = config.model_dim,
+      .ff_hidden_dim = 48,
+      .heads = 3,
+      .kv_heads = 1,
+      .qkv_dim = 12,
+  };
+  config.layer_configs = std::vector<LayerConfig>(2, layer_config);
+  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.query_scale = QueryScaleType::SqrtKeySize;
+  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
+  // This is required for optimize_test to pass.
+  config.final_cap = 30.0f;
+  return config;
+}
 
 TEST(BackPropTest, LayerVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  const size_t kOutputSize = TestConfig<T>::kSeqLen * TestConfig<T>::kModelDim;
-  CompressedLayer<TestConfig<T>> weights;
-  CompressedLayer<TestConfig<T>> grad;
-  ForwardLayer<T, TestConfig<T>> forward;
-  ForwardLayer<T, TestConfig<T>> backward = {};
-  CompressedLayer<TestConfig<TC>> c_weights;
-  ForwardLayer<TC, TestConfig<TC>> c_forward;
-  std::array<T, kOutputSize> y;
+  ModelConfig config = TestConfig();
+  const size_t kOutputSize = config.seq_len * config.model_dim;
+  LayerWeightsPtrs<T> weights(config.layer_configs[0]);
+  LayerWeightsPtrs<T> grad(config.layer_configs[0]);
+  ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
+  ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
+  LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
+  ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
+  MatStorageT<T> y("y", kOutputSize, 1);
   MatStorageT<T> dy("dy", kOutputSize, 1);
-  std::array<TC, kOutputSize> c_y;
+  MatStorageT<TC> c_y("c_y", kOutputSize, 1);
   const size_t num_tokens = 3;
-  weights.Allocate();
-  grad.Allocate();
-  c_weights.Allocate();
+  std::vector<MatStorage> layer_storage;
+  weights.Allocate(layer_storage);
+  grad.Allocate(layer_storage);
+  c_weights.Allocate(layer_storage);
   backward.input.ZeroInit();
 
   for (size_t iter = 0; iter < 10; ++iter) {
@@ -432,7 +436,7 @@ TEST(BackPropTest, LayerVJP) {
     Complexify(forward.input, c_forward.input);
     auto func = [&]() {
       ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
-      return DotT(dy.data(), c_y.data(), num_tokens * TestConfig<T>::kModelDim);
+      return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
     };
     grad.ZeroInit(/*layer_idx=*/0);
     ApplyLayer(weights, forward, num_tokens, y.data());
@@ -447,12 +451,13 @@ TEST(BackPropTest, EndToEnd) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  WeightsWrapper<T, TestConfig<T>> weights;
-  WeightsWrapper<T, TestConfig<T>> grad;
-  ForwardPass<T, TestConfig<T>> forward;
-  ForwardPass<T, TestConfig<T>> backward;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<T> weights(config);
+  WeightsWrapper<T> grad(config);
+  ForwardPass<T> forward(config);
+  ForwardPass<T> backward(config);
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
 
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
@@ -474,9 +479,9 @@ TEST(BackPropTest, EndToEnd) {
   }
 }
 
-template <typename T, typename TConfig>
-void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
-                       CompressedLayer<TConfig>& out) {
+template <typename T>
+void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
+                       LayerWeightsPtrs<T>& out) {
   MulByConstAndAddT(c, x.pre_attention_norm_scale,
                     out.pre_attention_norm_scale);
   MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
@@ -486,23 +491,23 @@ void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
   MulByConstAndAddT(c, x.linear_w, out.linear_w);
 }
 
-template <typename T, typename TConfig>
-void MulByConstAndAddT(T c, const CompressedWeights<TConfig>& x,
-                       CompressedWeights<TConfig>& out) {
-  static constexpr size_t kLayers = TConfig::kLayers;
+template <typename T>
+void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
+                       ModelWeightsPtrs<T>& out) {
+  const size_t layers = x.c_layers.size();
   MulByConstAndAddT(c, x.embedder_input_embedding,
                     out.embedder_input_embedding);
   MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
-  for (size_t i = 0; i < kLayers; ++i) {
+  for (size_t i = 0; i < layers; ++i) {
     MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
   }
 }
 
 // Evaluates forward pass on a batch.
-template <typename T, typename TConfig>
+template <typename T>
 T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
-                              const WeightsWrapper<T, TConfig>& weights,
-                              ForwardPass<T, TConfig>& forward) {
+                              const WeightsWrapper<T>& weights,
+                              ForwardPass<T>& forward) {
   T loss = 0.0;
   for (const Prompt& prompt : batch) {
     loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
@@ -514,12 +519,11 @@ T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
 // Evaluates forward pass on a batch by applying gradient with the given
 // learning rate. Does not update weights, but uses the given tmp weights
 // instead.
-template <typename T, typename TConfig>
+template <typename T>
 T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
-                              const WeightsWrapper<T, TConfig>& weights,
-                              const WeightsWrapper<T, TConfig>& grad,
-                              WeightsWrapper<T, TConfig>& tmp,
-                              ForwardPass<T, TConfig>& forward) {
+                              const WeightsWrapper<T>& weights,
+                              const WeightsWrapper<T>& grad,
+                              WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
   tmp.CopyFrom(weights);
   const T scale = -learning_rate / batch.size();
   MulByConstAndAddT(scale, grad.get(), tmp.get());
@@ -529,11 +533,9 @@ T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
 // Uses line search in the negative gradient direction to update weights. We do
 // this so that we can test that each step during the gradient descent can
 // decrease the objective function value.
-template <typename T, typename TConfig>
-T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
-                    WeightsWrapper<T, TConfig>& weights,
-                    WeightsWrapper<T, TConfig>& tmp,
-                    ForwardPass<T, TConfig>& forward,
+template <typename T>
+T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
+                    WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
                     const std::vector<Prompt>& batch, T loss,
                     T initial_learning_rate) {
   T lr0 = initial_learning_rate;
@@ -568,13 +570,14 @@ TEST(BackProptest, Convergence) {
   std::mt19937 gen(42);
   using T = float;
   using TC = std::complex<double>;
-  WeightsWrapper<T, TestConfig<T>> weights;
-  WeightsWrapper<T, TestConfig<T>> grad;
-  WeightsWrapper<T, TestConfig<T>> tmp;
-  ForwardPass<T, TestConfig<T>> forward;
-  ForwardPass<T, TestConfig<T>> backward;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<T> weights(config);
+  WeightsWrapper<T> grad(config);
+  WeightsWrapper<T> tmp(config);
+  ForwardPass<T> forward(config);
+  ForwardPass<T> backward(config);
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
   constexpr size_t kBatchSize = 5;
   ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
   T learning_rate = 0.01;
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
index 01c5e73..2b82c12 100644
--- a/backprop/backward_test.cc
+++ b/backprop/backward_test.cc
@@ -19,7 +19,6 @@
 #include <stddef.h>
 
-#include <array>
 #include <complex>
 #include <cmath>  // std::abs
 #include <random>
@@ -34,7 +33,6 @@
 #include "backprop/test_util.h"
 #include "gemma/activations.h"
 #include "gemma/configs.h"
-#include "gemma/weights.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -50,6 +48,7 @@
 #include "backprop/forward-inl.h"
 #include "compression/compress.h"
 #include "ops/ops-inl.h"
+#include "util/allocator.h"
 
 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -85,8 +84,8 @@ void TestMatMulVJP() {
     };
 
     grad.ZeroInit();
-    MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
-                            grad.data(), dx.data(), pool);
+    MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
+              grad.data(), dx.data(), pool);
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -130,9 +129,8 @@ void TestMultiHeadMatMulVJP() {
     };
 
     grad.ZeroInit();
-    MultiHeadMatMulVJP<kHeads, kCols, kRows>(
-        weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
-        pool);
+    MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
+                       kRows, kTokens, grad.data(), dx.data(), pool);
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@@ -186,63 +184,63 @@ void TestRMSNormVJP() {
   }
 }
 
-template <typename T>
-struct TestConfig : ConfigBaseGemmaV2 {
-  using Weight = T;
-  static constexpr int kSeqLen = 24;
-  static constexpr int kVocabSize = 16;
-  static constexpr int kModelDim = 32;
-  static constexpr int kHeads = 3;
-  static constexpr int kQKVDim = 16;
-  static constexpr int kFFHiddenDim = 64;
-  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
-      FixedLayerConfig<2>(LayerAttentionType::kGemma);
-  static constexpr int kLayers = kLayerConfig.size();
-  static constexpr int kNumTensorScales = 4 * kLayers;
-  static constexpr bool kAbsolutePE = false;
-  static constexpr PostNormType kPostNorm = PostNormType::None;
-
-  static constexpr int kKVHeads = 1;
-  static constexpr int kGemmaLayers = kLayers;
-};
+static ModelConfig TestConfig() {
+  ModelConfig config;
+  config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
+                        "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
+  config.model_dim = 32;
+  config.vocab_size = 16;
+  config.seq_len = 24;
+  LayerConfig layer_config = {
+      .model_dim = config.model_dim,
+      .ff_hidden_dim = 64,
+      .heads = 3,
+      .kv_heads = 1,
+      .qkv_dim = 16,
+  };
+  config.layer_configs = std::vector<LayerConfig>(2, layer_config);
+  config.num_tensor_scales = 4 * config.layer_configs.size();
+  config.query_scale = QueryScaleType::SqrtKeySize;
+  config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
+  // This is required for optimize_test to pass.
+  config.att_cap = 50.0f;
+  config.final_cap = 30.0f;
+  return config;
+}
 
 void TestEndToEnd() {
   std::mt19937 gen(42);
   hwy::ThreadPool pool(0);
-  using WeightsF = CompressedWeights<TestConfig<float>>;
-  using LayerF = CompressedLayer<TestConfig<float>>;
-  WeightsWrapper<float, TestConfig<float>> weights;
-  WeightsWrapper<float, TestConfig<float>> grad;
-  ActivationsWrapper<float, TestConfig<float>> forward0;
-  ActivationsWrapper<float, TestConfig<float>> forward1;
-  ActivationsWrapper<float, TestConfig<float>> backward;
+  ModelConfig config = TestConfig();
+  WeightsWrapper<float> weights(config);
+  WeightsWrapper<float> grad(config);
+  ForwardPass<float> forward0(config);
+  ForwardPass<float> forward1(config);
+  ForwardPass<float> backward(config);
   using TC = std::complex<double>;
-  WeightsWrapper<TC, TestConfig<TC>> c_weights;
-  ForwardPass<TC, TestConfig<TC>> c_forward;
+  WeightsWrapper<TC> c_weights(config);
+  ForwardPass<TC> c_forward(config);
 
   ReverseSequenceSampler training_task({0, 0, 1, 1});
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
 
-  RowVectorBatch<float> inv_timescale =
-      Activations::CreateInvTimescale<TestConfig<float>>();
+  RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
+      config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
   for (const Prompt& prompt : batch) {
     ReverseSequenceSampler::LogPrompt(prompt);
     RandInit(weights.get(), 1.0f, gen);
-    float loss0 = CrossEntropyLossForwardPass(
-        prompt, weights.get(), forward0.get());
+    float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
 
-    float loss1 =
-        CrossEntropyLossForwardPass<TestConfig<float>, WeightsF, LayerF>(
-            prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
-            inv_timescale, pool);
+    float loss1 = CrossEntropyLossForwardPass(
+        prompt.tokens, prompt.context_size, weights.get(), forward1,
+        inv_timescale, pool);
 
     EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
 
     grad.ZeroInit();
-    CrossEntropyLossBackwardPass<TestConfig<float>, WeightsF, LayerF>(
-        prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
-        inv_timescale, pool);
+    CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
+                                    backward, inv_timescale, pool);
 
     Complexify(weights.get(), c_weights.get());
     auto func = [&]() {
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index b6b1dc0..ca969c4 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -26,6 +26,7 @@
 #include "backprop/activations.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
+#include "gemma/weights.h"
 #include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -93,29 +94,29 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
   return loss * scaling;
 }
 
-template <typename TConfig, typename LayerT>
-void ApplyForwardLayer(const LayerT& weights,
-                       ForwardLayer<float, TConfig>& activations,
-                       size_t num_tokens, float* HWY_RESTRICT output,
+template <typename T>
+void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
+                       ForwardLayer<T>& activations, size_t num_tokens,
+                       float* HWY_RESTRICT output,
                        const RowVectorBatch<float>& inv_timescale,
                        hwy::ThreadPool& pool) {
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static constexpr size_t kQKVDim = TConfig::kQKVDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static const float kQueryScale =
+  const LayerConfig& config = weights.layer_config;
+  const size_t model_dim = config.model_dim;
+  const size_t kSeqLen = activations.input.Rows();
+  const size_t kQKVDim = config.qkv_dim;
+  const size_t kHeads = config.heads;
+  static const float query_scale =
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
   HWY_ASSERT(num_tokens <= kSeqLen);
 
   ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
-               activations.input.data(), kModelDim, num_tokens,
+               activations.input.data(), model_dim, num_tokens,
                activations.pre_att_rms_out.data(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
-        weights.qkv_einsum_w, 0,
-        activations.pre_att_rms_out.data() + pos * kModelDim,
-        activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
+    MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
+           activations.pre_att_rms_out.data() + pos * model_dim,
+           activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
   }
 
   const size_t num_tasks = kHeads * num_tokens;
@@ -130,7 +131,7 @@ void ApplyForwardLayer(const LayerT& weights,
       float* HWY_RESTRICT q =
          activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
       Rope(q, kQKVDim, inv_timescale.Const(), pos);
-      MulByConst(kQueryScale, q, kQKVDim);
+      MulByConst(query_scale, q, kQKVDim);
     });
 
   pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
@@ -174,29 +175,29 @@ void ApplyForwardLayer(const LayerT& weights,
   activations.attention_out.ZeroInit();
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     for (size_t head = 0; head < kHeads; ++head) {
-      MatVec<kModelDim, kQKVDim>(
-          weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
+      MatVec(
+          weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
+          kQKVDim,
          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
-          activations.att_post1.data() + pos * kModelDim, pool);
-      AddFrom(activations.att_post1.data() + pos * kModelDim,
-              activations.attention_out.data() + pos * kModelDim, kModelDim);
+          activations.att_post1.data() + pos * model_dim, pool);
+      AddFrom(activations.att_post1.data() + pos * model_dim,
+              activations.attention_out.data() + pos * model_dim, model_dim);
     }
   }
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(activations.input.data() + pos * kModelDim,
-            activations.attention_out.data() + pos * kModelDim, kModelDim);
+    AddFrom(activations.input.data() + pos * model_dim,
+            activations.attention_out.data() + pos * model_dim, model_dim);
   }
 
   ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
-               activations.attention_out.data(), kModelDim, num_tokens,
+               activations.attention_out.data(), model_dim, num_tokens,
                activations.bf_pre_ffw_rms_out.data(), pool);
 
-  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  const size_t kFFHiddenDim = config.ff_hidden_dim;
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    MatVec<kFFHiddenDim * 2, kModelDim>(
-        weights.gating_einsum_w, 0,
-        activations.bf_pre_ffw_rms_out.data() + pos * kModelDim,
-        activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
+    MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
+           activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
+           activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
   }
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     const size_t hidden_offset = pos * kFFHiddenDim * 2;
@@ -215,77 +216,76 @@ void ApplyForwardLayer(const LayerT& weights,
     }
   }
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    MatVec<kModelDim, kFFHiddenDim>(
-        weights.linear_w, 0,
-        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
-        output + pos * kModelDim, pool);
+    MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
+           activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
+           output + pos * model_dim, pool);
   }
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(activations.attention_out.data() + pos * kModelDim,
-            output + pos * kModelDim, kModelDim);
+    AddFrom(activations.attention_out.data() + pos * model_dim,
+            output + pos * model_dim, model_dim);
   }
 }
 
-template <typename TConfig, typename WeightsT, typename LayerT>
+template <typename T>
 float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
-                                  size_t context_size, const WeightsT& weights,
-                                  ForwardPass<float, TConfig>& forward,
+                                  size_t context_size,
+                                  const ModelWeightsPtrs<T>& weights,
+                                  ForwardPass<T>& forward,
                                   const RowVectorBatch<float>& inv_timescale,
                                   hwy::ThreadPool& pool) {
-  static constexpr size_t kVocabSize = TConfig::kVocabSize;
-  static constexpr size_t kModelDim = TConfig::kModelDim;
-  static constexpr size_t kLayers = TConfig::kLayers;
-  const float kEmbScaling = EmbeddingScaling<TConfig>();
-  static_assert(!TConfig::kAbsolutePE);
-  static_assert(TConfig::kPostNorm == PostNormType::None);
-  static_assert(TConfig::kKVHeads == 1);
+  const ModelConfig& config = weights.weights_config;
+  const size_t vocab_size = config.vocab_size;
+  const size_t model_dim = config.model_dim;
+  const size_t layers = config.layer_configs.size();
+  const float emb_scaling = EmbeddingScaling(model_dim);
+  HWY_ASSERT(!config.absolute_pe);
+  HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
+  HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
 
   HWY_DASSERT(context_size > 0);
   HWY_DASSERT(context_size < prompt.size());
   const size_t num_tokens = prompt.size() - 1;
 
-  InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
-                 forward.layers[0].input.data(), kModelDim, kVocabSize);
+  InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
+                 forward.layers[0].input.data(), model_dim, vocab_size);
 
-  for (size_t layer = 0; layer < kLayers; ++layer) {
-    auto type = TConfig::kLayerConfig[layer];
+  for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
+    auto type = config.layer_configs[layer].type;
     // TODO(szabadka) Implement Griffin layer.
     HWY_ASSERT(type == LayerAttentionType::kGemma);
-    float* HWY_RESTRICT output = layer + 1 < kLayers ?
- forward.layers[layer + 1].input.data() : - forward.final_layer_output.data(); - ApplyForwardLayer(*weights.GetLayer(layer), - forward.layers[layer], num_tokens, - output, inv_timescale, pool); + float* HWY_RESTRICT output = layer + 1 < layers + ? forward.layers[layer + 1].input.data() + : forward.final_layer_output.data(); + ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer], + num_tokens, output, inv_timescale, pool); } ApplyRMSNorm(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - kModelDim, num_tokens, forward.final_norm_output.data(), pool); + forward.final_layer_output.data(), model_dim, num_tokens, + forward.final_norm_output.data(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { - MatVec( - weights.embedder_input_embedding, 0, - forward.final_norm_output.data() + pos * kModelDim, - forward.logits.data() + pos * kVocabSize, pool); + MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim, + forward.final_norm_output.data() + pos * model_dim, + forward.logits.data() + pos * vocab_size, pool); } - if constexpr (TConfig::kFinalCap > 0.0f) { + if (config.final_cap > 0.0f) { for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(TConfig::kFinalCap, - forward.logits.data() + pos * kVocabSize, kVocabSize); + LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size, + vocab_size); } } hwy::CopyBytes(forward.logits.data(), forward.probs.data(), - num_tokens * kVocabSize * sizeof(forward.logits.At(0))); + num_tokens * vocab_size * sizeof(forward.logits.At(0))); for (size_t pos = 0; pos < num_tokens; ++pos) { - Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); + Softmax(forward.probs.data() + pos * vocab_size, vocab_size); } return CrossEntropyLoss(forward.probs.data(), prompt, context_size, - kVocabSize, pool); + vocab_size, pool); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/backprop/forward.cc b/backprop/forward.cc index 5b2cf1a..0c6cc5c 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -17,8 +17,9 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -36,38 +37,13 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template -float CrossEntropyLossForwardPass(const Prompt& prompt, - const ByteStorageT& weights_u8, - ByteStorageT& forward_u8, - RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { - const auto& weights = - *reinterpret_cast*>(weights_u8.get()); - auto& forward = - *reinterpret_cast*>(forward_u8.get()); - return CrossEntropyLossForwardPass, - CompressedLayer>( - prompt.tokens, prompt.context_size, weights, forward, inv_timescale, - pool); -} - -float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPassT(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { - // TODO(janwas): use CallFunctorForModel - switch (model) { - case Model::GEMMA_2B: - return CrossEntropyLossForwardPass>( - prompt, weights, forward, inv_timescale, pool); - case Model::GEMMA_TINY: - return CrossEntropyLossForwardPass>( - prompt, weights, forward, inv_timescale, pool); - default: - HWY_ABORT("Model type %d unknown.", 
static_cast(model)); - } + return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size, + weights, forward, inv_timescale, pool); } } // namespace HWY_NAMESPACE @@ -79,13 +55,13 @@ namespace gcpp { HWY_EXPORT(CrossEntropyLossForwardPassT); -float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( - model, prompt, weights, forward, inv_timescale, pool); + prompt, weights, forward, inv_timescale, pool); } } // namespace gcpp diff --git a/backprop/forward.h b/backprop/forward.h index 92ca371..3b42298 100644 --- a/backprop/forward.h +++ b/backprop/forward.h @@ -16,16 +16,17 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ +#include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" -#include "gemma/common.h" +#include "gemma/weights.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, - const ByteStorageT& weights, - ByteStorageT& forward, +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ModelWeightsPtrs& weights, + ForwardPass& forward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool); diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 064112b..617d0c3 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -127,108 +127,107 @@ void InputEmbedding(const T* w, const std::vector& tokens, T scaling, } } -template -void MaskedAttention(const T* qkv, T* output, size_t num_tokens, - size_t kHeads, size_t kQKVDim, size_t kSeqLen) { +template +void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads, + size_t qkv_dim, size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - const size_t qoffset = pos * (kHeads + 2) * kQKVDim; - const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen; - const T* q = qkv + qoffset + head * kQKVDim; + for (size_t head = 0; head < heads; ++head) { + const size_t qoffset = pos * (heads + 2) * qkv_dim; + const size_t aoffset = pos * heads * seq_len + head * seq_len; + const T* q = qkv + qoffset + head * qkv_dim; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; - output[aoffset + pos2] = DotT(q, k, kQKVDim); + const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim; + output[aoffset + pos2] = DotT(q, k, qkv_dim); } } } } -template -void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) { +template +void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + for (size_t head = 0; head < heads; ++head) { + size_t offset = pos * heads * seq_len + head * seq_len; Softmax(x + offset, pos + 1); - memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T)); } } } -template +template void MixByAttention(const T* qkv, const T* attention, T* output, - size_t num_tokens, size_t kHeads, size_t kQKVDim, - 
size_t kSeqLen) { + size_t num_tokens, size_t heads, size_t qkv_dim, + size_t seq_len) { for (size_t pos = 0; pos < num_tokens; ++pos) { - for (size_t head = 0; head < kHeads; ++head) { - const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen]; - T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim]; - memset(out, 0, kQKVDim * sizeof(out[0])); + for (size_t head = 0; head < heads; ++head) { + const T* att = &attention[pos * heads * seq_len + head * seq_len]; + T* out = &output[head * qkv_dim + pos * heads * qkv_dim]; + memset(out, 0, qkv_dim * sizeof(out[0])); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim; const T* v = &qkv[v_offset]; - MulByConstAndAddT(att[pos2], v, out, kQKVDim); + MulByConstAndAddT(att[pos2], v, out, qkv_dim); } } } } -template -void ApplyLayer(const CompressedLayer& weights, - ForwardLayer& activations, size_t num_tokens, - T* output) { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kSeqLen = TConfig::kSeqLen; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim)); +template +void ApplyLayer(const LayerWeightsPtrs& weights, + ForwardLayer& activations, size_t num_tokens, T* output) { + const LayerConfig& layer_config = weights.layer_config; + const size_t model_dim = layer_config.model_dim; + const size_t seq_len = activations.input.Rows(); + const size_t qkv_dim = layer_config.qkv_dim; + const size_t heads = layer_config.heads; + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim)); RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), - activations.pre_att_rms_out.data(), kModelDim, num_tokens); + activations.pre_att_rms_out.data(), model_dim, num_tokens); MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), - activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim, - num_tokens); + activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; - for (size_t h = 0; h <= kHeads; ++h) { - Rope(qkv + h * kQKVDim, kQKVDim, pos); + T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + for (size_t h = 0; h <= heads; ++h) { + Rope(qkv + h * qkv_dim, qkv_dim, pos); } } for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; - MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); + T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + MulByConstT(query_scale, qkv, heads * qkv_dim); } - MaskedAttention(activations.qkv.data(), activations.att.data(), - num_tokens, kHeads, kQKVDim, kSeqLen); + MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens, + heads, qkv_dim, seq_len); - MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen); + MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len); MixByAttention(activations.qkv.data(), activations.att.data(), - activations.att_out.data(), num_tokens, kHeads, kQKVDim, - kSeqLen); + activations.att_out.data(), num_tokens, heads, qkv_dim, + seq_len); MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), - 
activations.attention_out.data(), kHeads, kModelDim, kQKVDim, + activations.attention_out.data(), heads, model_dim, qkv_dim, num_tokens); AddFromT(activations.input.data(), activations.attention_out.data(), - num_tokens * kModelDim); + num_tokens * model_dim); RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), - activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens); + activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens); MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), - activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim, + activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim, num_tokens); GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), - kFFHiddenDim, num_tokens); + ff_hidden_dim, num_tokens); - MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), - output, kModelDim, kFFHiddenDim, num_tokens); + MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output, + model_dim, ff_hidden_dim, num_tokens); - AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim); + AddFromT(activations.attention_out.data(), output, num_tokens * model_dim); } template @@ -247,48 +246,47 @@ T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { return loss * scaling; } -template +template T CrossEntropyLossForwardPass(const Prompt& prompt, - const CompressedWeights& weights, - ForwardPass& forward) { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kVocabSize = TConfig::kVocabSize; - static constexpr size_t kLayers = TConfig::kLayers; + const ModelWeightsPtrs& weights, + ForwardPass& forward) { + const ModelConfig& config = weights.weights_config; + const size_t model_dim = config.model_dim; + const size_t vocab_size = config.vocab_size; + const size_t layers = config.layer_configs.size(); const std::vector tokens = prompt.tokens; const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; - const T kEmbScaling = EmbeddingScaling(kModelDim); - InputEmbedding(weights.embedder_input_embedding.data(), tokens, - kEmbScaling, forward.layers[0].input.data(), kModelDim); + const T kEmbScaling = EmbeddingScaling(model_dim); + InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling, + forward.layers[0].input.data(), model_dim); - for (size_t layer = 0; layer < kLayers; ++layer) { - T* output = layer + 1 < kLayers ? - forward.layers[layer + 1].input.data() : - forward.final_layer_output.data(); + for (size_t layer = 0; layer < layers; ++layer) { + T* output = layer + 1 < layers ? 
forward.layers[layer + 1].input.data() + : forward.final_layer_output.data(); ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, output); } - RMSNormT(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - forward.final_norm_output.data(), kModelDim, num_tokens); + RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(), + forward.final_norm_output.data(), model_dim, num_tokens); MatMulT(weights.embedder_input_embedding.data(), - forward.final_norm_output.data(), - forward.logits.data(), kVocabSize, kModelDim, num_tokens); + forward.final_norm_output.data(), forward.logits.data(), vocab_size, + model_dim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - if constexpr (TConfig::kFinalCap > 0.0f) { - Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, - kVocabSize); + if (config.final_cap > 0.0f) { + Softcap(config.final_cap, forward.logits.data() + pos * vocab_size, + vocab_size); } } memcpy(forward.probs.data(), forward.logits.data(), - num_tokens * kVocabSize * sizeof(forward.logits.At(0))); - Softmax(forward.probs.data(), kVocabSize, num_tokens); + num_tokens * vocab_size * sizeof(forward.logits.At(0))); + Softmax(forward.probs.data(), vocab_size, num_tokens); - return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); + return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size); } } // namespace gcpp diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 26698c6..b47a48d 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -26,8 +27,10 @@ #include "backprop/optimizer.h" #include "backprop/prompt.h" #include "backprop/sampler.h" +#include "compression/shared.h" #include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" #include "gemma/gemma.h" #include "gemma/weights.h" #include "util/threading.h" @@ -45,20 +48,18 @@ TEST(OptimizeTest, GradientDescent) { .training = ModelTraining::GEMMA_IT, .weight = Type::kF32, }; - ByteStorageT grad = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT grad_m = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT grad_v = CallForModelAndWeight( - info.model, info.weight, pool); - ByteStorageT forward = - CallForModelAndWeight(info.model, info.weight); - ByteStorageT backward = - CallForModelAndWeight(info.model, info.weight); - KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); + ModelConfig config = ConfigFromModel(info.model); + ModelWeightsStorage grad, grad_m, grad_v; + grad.Allocate(info.model, info.weight, pool); + grad_m.Allocate(info.model, info.weight, pool); + grad_v.Allocate(info.model, info.weight, pool); + grad_m.ZeroInit(); + grad_v.ZeroInit(); + ForwardPass forward(config), backward(config); + KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); - RowVectorBatch inv_timescale = - Activations::CreateInvTimescale>(); + RowVectorBatch inv_timescale = Activations::CreateInvTimescale( + config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk); Gemma gemma(GemmaTokenizer(), info, pools); @@ -92,14 +93,11 @@ TEST(OptimizeTest, GradientDescent) { reply.begin() + context.size()); }; - RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen); - CallForModelAndWeight(info.model, info.weight, - grad_m, pool); - CallForModelAndWeight(info.model, info.weight, - grad_v, pool); + gemma.MutableWeights().RandInit(gen); + 
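
Note on the backprop changes up to this point: they all apply one pattern, namely replacing `TConfig::k...` compile-time constants with loads from a runtime `LayerConfig`/`ModelConfig`, so a single compiled body serves every model and only the weight types need instantiating. A minimal standalone sketch of the pattern; `ApplyLayerSketch` is a hypothetical stand-in, not a function in this repo:

#include <cstddef>
#include <vector>

struct LayerConfig {  // mirrors the fields the diff reads at runtime
  size_t model_dim;
  size_t heads;
  size_t qkv_dim;
};

// Was: template <typename TConfig> reading TConfig::kModelDim etc., which
// forced one instantiation per model configuration. Now the same body reads
// plain members, so only the per-weight-type instantiation files remain.
void ApplyLayerSketch(const LayerConfig& config, std::vector<float>& qkv) {
  const size_t model_dim = config.model_dim;  // was TConfig::kModelDim
  const size_t heads = config.heads;          // was TConfig::kHeads
  const size_t qkv_dim = config.qkv_dim;      // was TConfig::kQKVDim
  (void)model_dim;
  qkv.resize((heads + 2) * qkv_dim);  // one Q row per head, plus K and V
}
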
gemma.MutableWeights().AllocAndCopyWithTranspose(pool); printf("Initial weights:\n"); - LogWeightStats(info.model, info.weight, gemma.Weights()); + gemma.MutableWeights().LogWeightStats(); constexpr size_t kBatchSize = 8; const float alpha = 0.001f; @@ -113,29 +111,29 @@ TEST(OptimizeTest, GradientDescent) { size_t num_ok; for (; steps < 1000000; ++steps) { std::mt19937 sgen(42); - CallForModelAndWeight(info.model, info.weight, - grad, pool); + grad.ZeroInit(); float total_loss = 0.0f; num_ok = 0; for (size_t i = 0; i < kBatchSize; ++i) { Prompt prompt = training_task.Sample(sgen); total_loss += CrossEntropyLossForwardPass( - info.model, prompt, gemma.Weights(), forward, inv_timescale, pool); - CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward, - grad, backward, inv_timescale, pool); - CallForModelAndWeight( - info.model, info.weight, gemma.MutableWeights(), pool); + prompt, *gemma.Weights().GetWeightsOfType(), forward, + inv_timescale, pool); + CrossEntropyLossBackwardPass( + prompt, *gemma.Weights().GetWeightsOfType(), forward, + *grad.GetWeightsOfType(), backward, inv_timescale, pool); + gemma.MutableWeights().CopyWithTranspose(pool); num_ok += verify(prompt) ? 1 : 0; } total_loss /= kBatchSize; - AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon, - steps + 1, gemma.Weights(), grad_m, grad_v, pool); + AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1, + gemma.Weights(), grad_m, grad_v, pool); printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", steps, total_loss, num_ok, kBatchSize); if (steps % 100 == 0) { printf("Batch gradient:\n"); - LogWeightStats(info.model, info.weight, grad); + grad.LogWeightStats(); } if (total_loss < 0.5f) { break; @@ -143,7 +141,7 @@ TEST(OptimizeTest, GradientDescent) { } printf("Num steps: %zu\n", steps); printf("Final weights:\n"); - LogWeightStats(info.model, info.weight, gemma.Weights()); + gemma.MutableWeights().LogWeightStats(); EXPECT_LT(steps, 300); EXPECT_EQ(num_ok, kBatchSize); } diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 800f2fa..9187bf7 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -16,7 +16,6 @@ #include "backprop/optimizer.h" #include -#include #include "compression/compress.h" #include "gemma/common.h" @@ -30,37 +29,6 @@ namespace gcpp { namespace { -class WeightInitializer { - public: - WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} - - void operator()(const char* name, hwy::Span tensors) { - float* data = tensors[0]->data(); - for (size_t i = 0; i < tensors[0]->NumElements(); ++i) { - data[i] = dist_(gen_); - } - tensors[0]->set_scale(1.0f); - } - - private: - std::normal_distribution dist_; - std::mt19937& gen_; -}; - -template -struct RandInitWeightsT { - void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool, - std::mt19937& gen) const { - auto& weights = - *reinterpret_cast*>(weights_u8.get()); - // TODO(szabadka) Use the same weight initialization method as in the python - // version. 
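
The `WeightInitializer` functor deleted below filled each f32 tensor with N(0,1) samples and reset its scale; judging by the test changes above, `ModelWeightsStorage::RandInit(gen)` now plays that role. A minimal sketch of that initialization, with `Tensor` as a hypothetical stand-in for the repo's tensor type:

#include <random>
#include <vector>

struct Tensor {  // hypothetical stand-in, not the repo's MatStorageT
  std::vector<float> data;
  float scale = 1.0f;
};

void RandInitTensor(Tensor& t, std::mt19937& gen) {
  std::normal_distribution<float> dist(0.0f, 1.0f);
  for (float& v : t.data) v = dist(gen);
  t.scale = 1.0f;  // compressed tensors carry a scale; N(0,1) init resets it
}
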
- WeightInitializer init(gen); - CompressedWeights::ForEachTensor({&weights}, - ForEachType::kLoadNoToc, init); - } -}; - class AdamUpdater { public: explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon, @@ -97,42 +65,31 @@ class AdamUpdater { float epsilon_; }; -template -struct AdamUpdateT { - void operator()(const ByteStorageT& grad_u8, float alpha, float beta1, - float beta2, float epsilon, size_t t, - const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, - const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { - using TWeights = CompressedWeights; - auto& grad = *reinterpret_cast(grad_u8.get()); - auto& weights = *reinterpret_cast(weights_u8.get()); - auto& grad_m = *reinterpret_cast(grad_m_u8.get()); - auto& grad_v = *reinterpret_cast(grad_v_u8.get()); - AdamUpdater updater(alpha, beta1, beta2, epsilon, t); - TWeights::ForEachTensor( - {&grad, &weights, &grad_m, &grad_v}, ForEachType::kLoadNoToc, - [&updater](const char* name, hwy::Span tensors) { - updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); - }); - } -}; +void AdamUpdate(ModelWeightsPtrs* grad, float alpha, float beta1, + float beta2, float epsilon, size_t t, + ModelWeightsPtrs* weights, + ModelWeightsPtrs* grad_m, + ModelWeightsPtrs* grad_v, hwy::ThreadPool& pool) { + AdamUpdater updater(alpha, beta1, beta2, epsilon, t); + ModelWeightsPtrs::ForEachTensor( + {grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc, + [&updater](const char* name, hwy::Span tensors) { + updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); + }); +} } // namespace -void RandInitWeights(Model model_type, Type weight_type, - const ByteStorageT& weights, hwy::ThreadPool& pool, - std::mt19937& gen) { +void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha, + float beta1, float beta2, float epsilon, size_t t, + const ModelWeightsStorage& weights, + const ModelWeightsStorage& grad_m, + const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) { HWY_ASSERT(weight_type == Type::kF32); - CallForModel(model_type, weights, pool, gen); -} - -void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, - float alpha, float beta1, float beta2, float epsilon, size_t t, - const ByteStorageT& weights, const ByteStorageT& grad_m, - const ByteStorageT& grad_v, hwy::ThreadPool& pool) { - HWY_ASSERT(weight_type == Type::kF32); - CallForModel(model_type, grad, alpha, beta1, beta2, - epsilon, t, weights, grad_m, grad_v, pool); + AdamUpdate(grad.GetWeightsOfType(), alpha, beta1, beta2, epsilon, t, + weights.GetWeightsOfType(), + grad_m.GetWeightsOfType(), grad_v.GetWeightsOfType(), + pool); } } // namespace gcpp diff --git a/backprop/optimizer.h b/backprop/optimizer.h index b42f311..8b25c52 100644 --- a/backprop/optimizer.h +++ b/backprop/optimizer.h @@ -16,22 +16,17 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ -#include - #include "gemma/common.h" -#include "util/allocator.h" +#include "gemma/weights.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -void RandInitWeights(Model model_type, Type weight_type, - const ByteStorageT& weights, hwy::ThreadPool& pool, - std::mt19937& gen); - -void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, - float alpha, float beta1, float beta2, float epsilon, size_t t, - const ByteStorageT& weights, const ByteStorageT& grad_m, - const ByteStorageT& grad_v, hwy::ThreadPool& pool); +void AdamUpdate(Type weight_type, const 
ModelWeightsStorage& grad, float alpha, + float beta1, float beta2, float epsilon, size_t t, + const ModelWeightsStorage& weights, + const ModelWeightsStorage& grad_m, + const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/test_util.h b/backprop/test_util.h index bfa2cc5..86f99b1 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -21,11 +21,12 @@ #include #include #include +#include #include "gtest/gtest.h" #include "compression/compress.h" +#include "gemma/configs.h" #include "gemma/weights.h" -#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -39,8 +40,8 @@ void RandInit(MatPtrT& x, T stddev, std::mt19937& gen) { } // TODO: make a member of Layer. -template -void RandInit(CompressedLayer& w, T stddev, std::mt19937& gen) { +template +void RandInit(LayerWeightsPtrs& w, T stddev, std::mt19937& gen) { RandInit(w.pre_attention_norm_scale, stddev, gen); RandInit(w.attn_vec_einsum_w, stddev, gen); RandInit(w.qkv_einsum_w, stddev, gen); @@ -49,9 +50,9 @@ void RandInit(CompressedLayer& w, T stddev, std::mt19937& gen) { RandInit(w.linear_w, stddev, gen); } -template -void RandInit(CompressedWeights& w, T stddev, std::mt19937& gen) { - static constexpr size_t kLayers = TConfig::kLayers; +template +void RandInit(ModelWeightsPtrs& w, T stddev, std::mt19937& gen) { + const size_t kLayers = w.c_layers.size(); RandInit(w.embedder_input_embedding, stddev, gen); RandInit(w.final_norm_scale, stddev, gen); for (size_t i = 0; i < kLayers; ++i) { @@ -66,9 +67,8 @@ void Complexify(const MatPtrT& x, MatPtrT>& c_x) { } } -template -void Complexify(const CompressedLayer& w, - CompressedLayer& c_w) { +template +void Complexify(const LayerWeightsPtrs& w, LayerWeightsPtrs& c_w) { Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); @@ -77,10 +77,9 @@ void Complexify(const CompressedLayer& w, Complexify(w.linear_w, c_w.linear_w); } -template -void Complexify(const CompressedWeights& w, - CompressedWeights& c_w) { - static constexpr size_t kLayers = TConfig::kLayers; +template +void Complexify(const ModelWeightsPtrs& w, ModelWeightsPtrs& c_w) { + const size_t kLayers = w.c_layers.size(); Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); Complexify(w.final_norm_scale, c_w.final_norm_scale); for (size_t i = 0; i < kLayers; ++i) { @@ -88,26 +87,27 @@ void Complexify(const CompressedWeights& w, } } -// Owns weights and provides access to TConfig. -template +// Somewhat duplicates ModelWeightsStorage, but that has neither double nor +// complex types allowed and it would cause code bloat to add them there. 
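
For reference, the element-wise update `AdamUpdater` applies above is standard Adam with bias correction. A scalar sketch, with hyperparameter names matching the diff (`t` is 1-indexed, as in the test's `steps + 1`):

#include <cmath>
#include <cstddef>

void AdamStep(float g, float& w, float& m, float& v, float alpha, float beta1,
              float beta2, float epsilon, size_t t) {
  m = beta1 * m + (1.0f - beta1) * g;      // first-moment EMA
  v = beta2 * v + (1.0f - beta2) * g * g;  // second-moment EMA
  const float mhat = m / (1.0f - std::pow(beta1, static_cast<float>(t)));
  const float vhat = v / (1.0f - std::pow(beta2, static_cast<float>(t)));
  w -= alpha * mhat / (std::sqrt(vhat) + epsilon);  // bias-corrected step
}
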
+template class WeightsWrapper { public: - WeightsWrapper() - : pool_(0), - data_(AllocateCompressedWeights()(pool_)), - weights_(reinterpret_cast*>(data_.get())) {} + explicit WeightsWrapper(const ModelConfig& config) + : pool_(0), weights_(config, pool_) { + weights_.Allocate(data_, pool_); + } - const CompressedWeights& get() const { return *weights_; } - CompressedWeights& get() { return *weights_; } - void ZeroInit() { weights_->ZeroInit(); } - void CopyFrom(const WeightsWrapper& other) { - get().CopyFrom(other.get()); + const ModelWeightsPtrs& get() const { return weights_; } + ModelWeightsPtrs& get() { return weights_; } + void ZeroInit() { weights_.ZeroInit(); } + void CopyFrom(const WeightsWrapper& other) { + weights_.CopyFrom(other.weights_); } private: hwy::ThreadPool pool_; - ByteStorageT data_; - CompressedWeights* weights_; + std::vector data_; + ModelWeightsPtrs weights_; }; template @@ -173,9 +173,9 @@ void TestGradient(const MatPtrT& grad, MatPtrT>& x, TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); } -template -void TestGradient(const CompressedLayer& grad, - CompressedLayer& c_weights, FUNC func, T max_err) { +template +void TestGradient(const LayerWeightsPtrs& grad, + LayerWeightsPtrs& c_weights, FUNC func, T max_err) { TestGradient(grad.pre_attention_norm_scale, c_weights.pre_attention_norm_scale, func, max_err, max_err, __LINE__); @@ -191,15 +191,15 @@ void TestGradient(const CompressedLayer& grad, func, max_err, max_err, __LINE__); } -template -void TestGradient(const CompressedWeights& grad, - CompressedWeights& c_weights, FUNC func, T max_err) { +template +void TestGradient(const ModelWeightsPtrs& grad, + ModelWeightsPtrs& c_weights, FUNC func, T max_err) { TestGradient(grad.embedder_input_embedding, c_weights.embedder_input_embedding, func, 2 * max_err, max_err, __LINE__); TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, func, max_err, max_err, __LINE__); - for (int i = 0; i < TConfig::kLayers; ++i) { + for (size_t i = 0; i < grad.c_layers.size(); ++i) { TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err); } } diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 24248a1..57f50f5 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include #include @@ -276,6 +275,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { [pfile, &requests, &err](uint64_t i, size_t /*thread*/) { if (!pfile->Read(requests[i].offset, requests[i].size, requests[i].data)) { + fprintf(stderr, "Failed to read blob %zu\n", i); err.test_and_set(); } }); diff --git a/compression/compress.h b/compression/compress.h index e0ea0d7..adb35a1 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -102,8 +102,8 @@ class CompressedArray { class MatPtr { public: // Full constructor for dynamic sizing. 
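
The `Complexify`/`TestGradient` machinery above checks gradients via complex-step differentiation: evaluating the complexified loss at x + i*h gives df/dx ≈ Im(f(x + i*h)) / h with no subtractive cancellation, which is why the 1e-50 step passed to `TestGradient` is usable at all. A self-contained demonstration of the technique:

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const double x = 0.7;
  const double h = 1e-50;  // same step TestGradient uses above
  const std::complex<double> fx = std::sin(std::complex<double>(x, h));
  const double dfdx = fx.imag() / h;  // no cancellation, unlike (f(x+h)-f(x))/h
  std::printf("complex-step: %.15f  analytic: %.15f\n", dfdx, std::cos(x));
  return 0;
}
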
- MatPtr(const std::string& name, const std::string& type, size_t element_size, - size_t rows, size_t cols) + MatPtr(const std::string& name, Type type, size_t element_size, size_t rows, + size_t cols) : name_(name), type_(type), element_size_(element_size), @@ -129,7 +129,7 @@ class MatPtr { MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1, const hwy::uint128_t& key2, const hwy::uint128_t& key3) : name_(StringFromKey(key0)), - type_(StringFromKey(key1)), + type_(static_cast(key1.lo)), element_size_(key2.hi), num_elements_(key2.lo), rows_(key3.lo), @@ -138,7 +138,7 @@ class MatPtr { // Adds the contents entry to the table of contents. void AddToToc(std::vector& toc) const { toc.push_back(MakeKey(name_.c_str())); - toc.push_back(MakeKey(type_.c_str())); + toc.push_back({static_cast(type_), 0}); toc.push_back({num_elements_, element_size_}); toc.push_back({rows_, cols_}); } @@ -167,7 +167,7 @@ class MatPtr { void SetName(const std::string& name) { name_ = name; } // Returns the type of the blob. - const std::string& Type() const { return type_; } + Type GetType() const { return type_; } // Returns the size of each element in bytes. size_t ElementSize() const { return element_size_; } @@ -219,8 +219,8 @@ class MatPtr { protected: // Arbitrary name for the array of preferably <= 16 characters. std::string name_; - // Should be the result of TypeName for CallUpcasted() to work. - std::string type_; + // Should be the result of TypeEnum for CallUpcasted() to work. + Type type_; // sizeof(T) size_t element_size_ = 0; // Number of elements in the array. @@ -247,7 +247,7 @@ class MatPtrT : public MatPtr { // Full constructor for dynamic sizing. MatPtrT(const std::string& name, size_t rows, size_t cols) - : MatPtr(name, TypeName(), sizeof(MatT), rows, cols) {} + : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} // Copying allowed as the metadata is small. MatPtrT(const MatPtr& other) : MatPtr(other) {} @@ -330,17 +330,20 @@ class MatPtrT : public MatPtr { template decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { - if (type_ == TypeName()) { + if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); - } else if (type_ == TypeName()) { + } else if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); - } else if (type_ == TypeName()) { + } else if (type_ == TypeEnum()) { return func(dynamic_cast*>(this), std::forward(args)...); + } else if (type_ == TypeEnum()) { + return func(dynamic_cast*>(this), + std::forward(args)...); } else { - HWY_ABORT("Type %s unknown.", type_.c_str()); + HWY_ABORT("Type %d unknown.", type_); } } @@ -563,9 +566,10 @@ class CacheLoader { } // Returns whether all tensors are successfully loaded from cache. - bool ReadAll(hwy::ThreadPool& pool, std::vector& model_memory) { + BlobError ReadAll(hwy::ThreadPool& pool, + std::vector& model_memory) { // reader_ invalid or any Enqueue failed - if (err_ != 0) return false; + if (err_ != 0) return err_; // Setup the model_memory. 
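
`CallUpcasted` now keys on the `Type` enum instead of comparing type-name strings. A pared-down sketch of that tag-dispatch pattern; `Base`/`Derived` are hypothetical stand-ins for `MatPtr`/`MatPtrT`, and the real code uses `dynamic_cast` and forwards the functor's return value:

#include <cstdio>

enum class Type { kF32, kF64 };

struct Base {
  explicit Base(Type t) : type(t) {}
  virtual ~Base() = default;
  Type type;  // tag identifying the concrete Derived<T>
};

template <typename T>
struct Derived : public Base {
  explicit Derived(Type t) : Base(t) {}
  T value{};
};

template <typename Func>
void CallUpcastedSketch(Base* b, Func&& func) {
  switch (b->type) {
    case Type::kF32:
      func(static_cast<Derived<float>*>(b));
      break;
    case Type::kF64:
      func(static_cast<Derived<double>*>(b));
      break;
  }
}

int main() {
  Derived<float> f(Type::kF32);
  CallUpcastedSketch(&f, [](auto* d) { std::printf("%zu\n", sizeof(d->value)); });
}
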
for (int b = 0; b < model_toc_.size(); ++b) { const std::string& file_key = file_keys_[b]; @@ -574,12 +578,12 @@ class CacheLoader { const MatPtr* toc_blob = file_toc_.Get(file_key); if (toc_blob == nullptr) { fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str()); - return false; + return __LINE__; } if (toc_blob->Rows() != blob->Rows() || toc_blob->Cols() != blob->Cols()) { fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); - return false; + return __LINE__; } MatStorage toc_blob_array(*toc_blob); model_memory.push_back(std::move(toc_blob_array)); @@ -603,17 +607,10 @@ class CacheLoader { "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", blob.Name().c_str(), err_, blob.Rows(), blob.Cols(), blob.ElementSize()); - return false; + return err_; } } - - err_ = reader_.ReadAll(pool); - if (err_ != 0) { - fprintf(stderr, "Failed to read all tensors (error %d)\n", err_); - return false; - } - - return true; + return reader_.ReadAll(pool); } private: diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index 51897af..1a4fc52 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -24,6 +24,7 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "gemma/configs.h" #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE #define GEMMA_COMPRESS_WEIGHTS_ONCE @@ -150,29 +151,22 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template +template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, Model model_type, - Type weight_type, hwy::ThreadPool& pool) { + hwy::ThreadPool& pool) { if (!weights_path.Exists()) { HWY_ABORT("The model weights file '%s' does not exist.", weights_path.path.c_str()); } printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), compressed_weights_path.path.c_str()); - - using CConfig = typename Configs::c; - using UCConfig = typename Configs::uc; - // Allocate compressed weights. - using CWeights = CompressedWeights; - ByteStorageT c_weights_u8 = AllocateCompressedWeights()(pool); - CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); - - // Allocate uncompressed weights. - using UCWeights = CompressedWeights; - ByteStorageT uc_weights_u8 = AllocateCompressedWeights()(pool); - UCWeights* uc_weights = reinterpret_cast(uc_weights_u8.get()); - + ModelConfig config = ConfigFromModel(model_type); + std::vector model_storage; + ModelWeightsPtrs c_weights(config, pool); + c_weights.Allocate(model_storage, pool); + ModelWeightsPtrs uc_weights(config, pool); + uc_weights.Allocate(model_storage, pool); // Get uncompressed weights, compress, and store. 
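
`CacheLoader::ReadAll` above now returns a `BlobError` (0 on success, a `__LINE__`-style code for TOC mismatches, or the reader's I/O error) instead of a bare bool, so callers can report why loading failed. A hypothetical caller-side sketch, assuming only that `BlobError` is an int-like code:

#include <cstdio>

using BlobError = int;  // 0 means success, matching blob_store's convention

// Hypothetical stand-in for CacheLoader::ReadAll.
BlobError ReadAllSketch(bool toc_ok, BlobError io_err) {
  if (!toc_ok) return __LINE__;  // distinguishes TOC failures from I/O ones
  return io_err;
}

bool LoadOrReport() {
  if (const BlobError err = ReadAllSketch(/*toc_ok=*/true, /*io_err=*/0)) {
    std::fprintf(stderr, "Cache load failed (error %d)\n", err);
    return false;
  }
  return true;
}
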
FILE* fptr = fopen(weights_path.path.c_str(), "rb"); if (fptr == nullptr) { @@ -181,22 +175,22 @@ void CompressWeights(const Path& weights_path, } bool ok = true; uint64_t total_size = 0; - CompressedWeights::ForEachTensor( - {uc_weights}, ForEachType::kLoadNoToc, + ModelWeightsPtrs::ForEachTensor( + {&uc_weights}, ForEachType::kLoadNoToc, [&](const char* name, hwy::Span tensors) { fprintf(stderr, "Loading Parameters (size %zu): %s\n", tensors[0]->SizeBytes(), name); ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); total_size += tensors[0]->SizeBytes(); }); - const bool scale_for_compression = UCConfig::kNumTensorScales > 0; + const bool scale_for_compression = config.num_tensor_scales > 0; std::vector scales; if (scale_for_compression) { - uc_weights->GetOrApplyScales(scales); + uc_weights.GetOrApplyScales(scales); } Compressor compressor(pool); - CompressedWeights::ForEachTensor( - {reinterpret_cast*>(uc_weights), c_weights}, + ModelWeightsPtrs::ForEachTensor( + {reinterpret_cast*>(&uc_weights), &c_weights}, ForEachType::kLoadNoToc, [&compressor](const char* name, hwy::Span tensors) { tensors[1]->CallUpcasted( @@ -221,9 +215,26 @@ void Run(Args& args) { HWY_ABORT("PaliGemma is not supported in compress_weights."); } const Type weight_type = args.WeightType(); - GEMMA_EXPORT_AND_DISPATCH( - model_type, weight_type, CompressWeights, - (args.weights, args.compressed_weights, model_type, weight_type, pool)); + switch (weight_type) { + case Type::kF32: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kBF16: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kSFP: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + case Type::kNUQ: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) + (args.weights, args.compressed_weights, model_type, pool); + break; + default: + HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); + } } } // namespace gcpp diff --git a/compression/shared.h b/compression/shared.h index c216d24..74b7454 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -32,11 +32,6 @@ namespace gcpp { using BF16 = hwy::bfloat16_t; -template -constexpr bool IsF32() { - return hwy::IsSame, float>(); -} - // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 // inputs that combines the advantages of e4m3 and e5m2 into a single format. // It supports seeking at a granularity of 1 and decoding to bf16/f32. @@ -179,29 +174,67 @@ struct NuqStream { }; #pragma pack(pop) +template +constexpr bool IsF32() { + return hwy::IsSame, float>(); +} + +template +constexpr bool IsBF16() { + return hwy::IsSame, BF16>(); +} + +template +constexpr bool IsSfpStream() { + return hwy::IsSame, SfpStream>(); +} + +template +constexpr bool IsNuqStream() { + return hwy::IsSame, NuqStream>(); +} + +// Instruction-tuned models require extra 'turn structure' tokens in prompts. +enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA }; + +// Tensor types for loading weights. Note that not all types are supported as +// weights for a model, but can be used for other purposes, such as types for +// ModelWeightsPtrs. When adding a new type that is supported, also +// update gemma.cc, weights.*, and add instantiations/new_one.cc. 
+enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
+constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
+                                        "nuq",     "f64", "c64",  "u128"};
+
+// Returns a Type enum for the type of the template parameter.
 template <typename PackedT>
-const char* TypeName() {
+Type TypeEnum() {
   using Packed = hwy::RemoveCvRef<PackedT>;
   if constexpr (hwy::IsSame<Packed, float>()) {
-    return "f32";
+    return Type::kF32;
   } else if constexpr (hwy::IsSame<Packed, BF16>()) {
-    return "b16";
+    return Type::kBF16;
   } else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
-    return "sfp";
+    return Type::kSFP;
   } else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
-    return "nuq";
+    return Type::kNUQ;
   } else if constexpr (hwy::IsSame<Packed, double>()) {
-    return "f64";
+    return Type::kF64;
   } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
-    return "c64";
+    return Type::kC64;
   } else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
-    return "u128";
+    return Type::kU128;
   } else {
     HWY_DASSERT(false);
-    return "unknown";
+    return Type::kUnknown;
   }
 }
 
+// Returns a string name for the type of the template parameter.
+template <typename PackedT>
+const char* TypeName() {
+  return kTypeStrings[static_cast<size_t>(TypeEnum<PackedT>())];
+}
+
 template <typename Packed>
 constexpr bool IsCompressed() {
   return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index b59079a..1ea4f65 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -128,8 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(
-        env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size);
+    KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
+                                       env.MutableConfig().prefill_tbatch_size);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 63553aa..abae040 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -69,8 +69,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     model_ = AllocateGemma(mutable_loader, pools_);
     // Only allocate one for starters because GenerateBatch might not be called.
     kv_caches_.resize(1);
-    kv_caches_[0] =
-        KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);
+    kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
+                                    inference.prefill_tbatch_size);
   }
   InitGenerator(inference, gen_);
   runtime_config_ = {
@@ -163,7 +163,7 @@ std::vector<std::string> GemmaEnv::BatchQueryModel(
   }
   for (size_t i = 1; i < num_queries; ++i) {
     if (kv_caches_[i].seq_len == 0) {
-      kv_caches_[i] = KVCache::Create(model_->Info().model,
+      kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
                                       runtime_config_.prefill_tbatch_size);
     }
   }
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index 870f84c..13ff3d3 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -103,8 +103,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   const StreamFunc stream_token = [](int /*token*/, float) { return true; };
   // TWeight is unused, but we have to pass it to Config*.
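
A quick check of the `TypeEnum`/`TypeName` round-trip introduced above: the enum value indexes `kTypeStrings`, so the two stay in sync by construction. Standalone sketch; it re-declares minimal pieces and uses `std::is_same_v` in place of `hwy::IsSame`:

#include <cstdio>
#include <type_traits>

enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
                                        "nuq",     "f64", "c64",  "u128"};

template <typename T>
constexpr Type TypeEnumSketch() {
  if constexpr (std::is_same_v<T, float>) return Type::kF32;
  else if constexpr (std::is_same_v<T, double>) return Type::kF64;
  else return Type::kUnknown;
}

template <typename T>
constexpr const char* TypeNameSketch() {
  // The enum value is the index into the string table.
  return kTypeStrings[static_cast<int>(TypeEnumSketch<T>())];
}

int main() {
  std::printf("%s %s\n", TypeNameSketch<float>(), TypeNameSketch<double>());
  return 0;  // prints: f32 f64
}
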
- const int vocab_size = - CallForModel(gemma.Info().model); + const int vocab_size = gemma.GetModelConfig().vocab_size; float cross_entropy = std::log(vocab_size); // first token size_t pos = 1; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 39d4f9c..2ed9b64 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -24,7 +24,6 @@ #include // Placeholder for internal header, do not modify. -#include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/tokenizer.h" #include "util/app.h" // LoaderArgs @@ -58,7 +57,8 @@ int main(int argc, char** argv) { gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::KVCache kv_cache = - gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); + gcpp::KVCache::Create(model.GetModelConfig(), + inference.prefill_tbatch_size); size_t generated = 0; // Initialize random number generator diff --git a/gemma/activations.h b/gemma/activations.h index b10b562..3983924 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,6 +21,7 @@ #include #include "compression/shared.h" // BF16 +#include "gemma/configs.h" #include "ops/matmul.h" // MatMulEnv #include "util/allocator.h" // RowVectorBatch #include "util/threading.h" @@ -30,6 +31,12 @@ namespace gcpp { struct Activations { + explicit Activations(const ModelConfig& config) + : weights_config(config), + layer_config(config.layer_configs[0]), + seq_len(config.seq_len), + cache_pos_size(config.CachePosSize()) {} + RowVectorBatch x; // input RowVectorBatch q; // query, also KV if MHA. RowVectorBatch logits; @@ -58,23 +65,24 @@ struct Activations { MatMulEnv env; + PostQKType post_qk = PostQKType::Rope; + // And the config. + const ModelConfig& weights_config; + const LayerConfig& layer_config; + size_t seq_len; + size_t cache_pos_size = 0; + // Multi-Head Attention? - template - static constexpr bool IsMHA() { - return TConfig::kHeads == TConfig::kKVHeads; - } + bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; } // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - template - static constexpr size_t QStride() { - return TConfig::kQKVDim * (IsMHA() ? 3 : 1); - } + size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); } - template - static RowVectorBatch CreateInvTimescale() { - constexpr size_t kQKVDim = TConfig::kQKVDim; - const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim; + static RowVectorBatch CreateInvTimescale(size_t qkv_dim, + PostQKType post_qk) { + const size_t rope_dim = + post_qk == PostQKType::HalfRope ? 
qkv_dim / 2 : qkv_dim; RowVectorBatch inv_timescale(1, rope_dim / 2); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const float freq_exponents = @@ -86,40 +94,38 @@ struct Activations { return inv_timescale; } - template void Allocate(size_t batch_size, PerClusterPools& pools) { - constexpr size_t kModelDim = TConfig::kModelDim; - constexpr size_t kQKVDim = TConfig::kQKVDim; - constexpr size_t kHeads = TConfig::kHeads; - constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - constexpr size_t kVocabSize = TConfig::kVocabSize; - constexpr size_t kSeqLen = TConfig::kSeqLen; - constexpr size_t kGriffinLayers = TConfig::kGriffinLayers; + post_qk = layer_config.post_qk; + const size_t model_dim = weights_config.model_dim; + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + const size_t vocab_size = weights_config.vocab_size; - x = RowVectorBatch(batch_size, kModelDim); - q = RowVectorBatch(batch_size, kHeads * QStride()); - if constexpr (kVocabSize > 0) { - logits = RowVectorBatch(batch_size, kVocabSize); + x = RowVectorBatch(batch_size, model_dim); + q = RowVectorBatch(batch_size, layer_config.heads * QStride()); + if (vocab_size > 0) { + logits = RowVectorBatch(batch_size, vocab_size); } - pre_att_rms_out = RowVectorBatch(batch_size, kModelDim); - att = RowVectorBatch(batch_size, kHeads * kSeqLen); - att_out = RowVectorBatch(batch_size, kHeads * kQKVDim); - att_sums = RowVectorBatch(batch_size, kModelDim); + pre_att_rms_out = RowVectorBatch(batch_size, model_dim); + att = RowVectorBatch(batch_size, + layer_config.heads * weights_config.seq_len); + att_out = RowVectorBatch(batch_size, + layer_config.heads * layer_config.qkv_dim); + att_sums = RowVectorBatch(batch_size, model_dim); - bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim); - C1 = RowVectorBatch(batch_size, kFFHiddenDim); - C2 = RowVectorBatch(batch_size, kFFHiddenDim); - ffw_out = RowVectorBatch(batch_size, kModelDim); + bf_pre_ffw_rms_out = RowVectorBatch(batch_size, model_dim); + C1 = RowVectorBatch(batch_size, ff_hidden_dim); + C2 = RowVectorBatch(batch_size, ff_hidden_dim); + ffw_out = RowVectorBatch(batch_size, model_dim); - if constexpr (kGriffinLayers > 0) { - griffin_x = RowVectorBatch(batch_size, kModelDim); - griffin_y = RowVectorBatch(batch_size, kModelDim); - griffin_gate_x = RowVectorBatch(batch_size, kModelDim); - griffin_multiplier = RowVectorBatch(batch_size, kModelDim); + if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { + griffin_x = RowVectorBatch(batch_size, model_dim); + griffin_y = RowVectorBatch(batch_size, model_dim); + griffin_gate_x = RowVectorBatch(batch_size, model_dim); + griffin_multiplier = RowVectorBatch(batch_size, model_dim); } - inv_timescale = CreateInvTimescale(); + inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk); env = MatMulEnv(pools); } diff --git a/gemma/common.cc b/gemma/common.cc index e68347b..447deb6 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -15,6 +15,7 @@ #include "gemma/common.h" +#include // sqrtf #include #include @@ -23,6 +24,7 @@ #include #include +#include "compression/shared.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -101,8 +103,6 @@ const char* ModelString(Model model, ModelTraining training) { static_cast(training)); } -constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"}; - const char* StringFromType(Type type) { return kTypeStrings[static_cast(type)]; } @@ -141,4 +141,19 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { prompt = start + 
prompt + "\nmodel\n"; } } + +float EmbeddingScaling(size_t model_dim) { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo(hwy::ConvertScalarTo( + sqrtf(static_cast(model_dim)))); +} + +float ChooseQueryScale(const ModelConfig& config) { + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / sqrtf(static_cast(config.model_dim / + config.layer_configs[0].heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); +} + } // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h index 18ac5d1..e933e8d 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -16,37 +16,15 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ -#include // sqrtf #include #include -#include "compression/compress.h" #include "gemma/configs.h" // IWYU pragma: export #include "hwy/base.h" // ConvertScalarTo namespace gcpp { -// Model variants: see configs.h for details. When adding a new one, also -// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. -enum class Model { - GEMMA_2B, - GEMMA_7B, - GEMMA2_9B, - GEMMA2_27B, - GRIFFIN_2B, - GEMMA_TINY, - GEMMA2_2B, - PALIGEMMA_224, -}; - -// Instruction-tuned models require extra 'turn structure' tokens in prompts. -enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA }; - -// Tensor types for loading weights. When adding a new one, also -// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. -enum class Type { kF32, kBF16, kSFP }; - // TODO(janwas): merge with functions below. struct ModelInfo { Model model; @@ -66,198 +44,12 @@ const char* StringFromType(Type type); void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); -// Returns the return value of FuncT>().operator()(args), where -// Config* is selected via `model`. Typically called by CallForModelAndWeight, -// but can also be called directly when FuncT does not actually use TWeight. -// -// Note that a T prefix indicates a concrete type template argument, whereas a -// T suffix indicates the argument is itself a template. -// -// `FuncT` must be a functor because function templates cannot be passed as a -// template template argument, and we prefer to avoid the overhead of -// std::function. -template class FuncT, - typename... TArgs> -decltype(auto) CallForModel(Model model, TArgs&&... args) { - switch (model) { - case Model::GEMMA_TINY: - return FuncT>()(std::forward(args)...); - case Model::GEMMA_2B: - return FuncT>()(std::forward(args)...); - case Model::GEMMA_7B: - return FuncT>()(std::forward(args)...); - case Model::GEMMA2_9B: - return FuncT>()(std::forward(args)...); - case Model::GEMMA2_27B: - return FuncT>()(std::forward(args)...); - case Model::GRIFFIN_2B: - return FuncT>()(std::forward(args)...); - case Model::GEMMA2_2B: - return FuncT>()(std::forward(args)...); - case Model::PALIGEMMA_224: - return FuncT>()( - std::forward(args)...); - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - -// Returns the return value of FuncT().operator()(args), -// where `TConfig` is selected based on `model` and `weight`. - -// This makes it easy to extend `Model` or `Type` without updating callers. -// -// Usage example: LoadWeights is type-erased so that it can be called from other -// .cc files. It uses this function to call the appropriate instantiation of a -// template functor LoadCompressedWeightsT. -template