diff --git a/BUILD.bazel b/BUILD.bazel index 022464c..eeac7cd 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -75,9 +75,7 @@ cc_library( ":allocator", ":threading", "//compression:compress", - "//compression:sfp", "@hwy//:algo", - "@hwy//:dot", "@hwy//:hwy", "@hwy//:math", "@hwy//:matvec", @@ -149,7 +147,6 @@ cc_test( "//compression:compress", "@hwy//:hwy", "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", "@hwy//:thread_pool", ], ) @@ -281,11 +278,9 @@ cc_library( "//paligemma:image", "@hwy//:hwy", "@hwy//:bit_set", - "@hwy//:matvec", "@hwy//:nanobenchmark", # timer "@hwy//:profiler", "@hwy//:thread_pool", - "@hwy//:topology", ], ) @@ -481,6 +476,7 @@ cc_library( ":ops", ":prompt", ":weights", + "//compression:compress", "@hwy//:dot", "@hwy//:hwy", # base.h "@hwy//:thread_pool", @@ -498,9 +494,10 @@ cc_library( deps = [ ":allocator", ":common", - ":gemma_lib", ":prompt", - "//compression:weights_raw", + ":weights", + "//compression:compress", + "@hwy//:hwy", ], ) @@ -512,13 +509,15 @@ cc_test( "backprop/test_util.h", ], deps = [ + ":allocator", ":backprop_scalar", ":common", - ":gemma_lib", ":prompt", ":sampler", + ":weights", "@googletest//:gtest_main", - "//compression:weights_raw", + "//compression:compress", + "@hwy//:thread_pool", ], ) @@ -534,6 +533,7 @@ cc_test( "mem": "28g", }, deps = [ + ":allocator", ":backprop", ":backprop_scalar", ":common", @@ -541,8 +541,9 @@ cc_test( ":ops", ":prompt", ":sampler", + ":weights", "@googletest//:gtest_main", - "//compression:weights_raw", + "//compression:compress", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:thread_pool", diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e1c70a..51e8891 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 457c891775a7397bdb0376bb1031e6e027af1c48 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. 
This will only happen if @@ -39,6 +39,7 @@ FetchContent_MakeAvailable(benchmark) set(SOURCES compression/blob_store.cc compression/blob_store.h + compression/compress.cc compression/compress.h compression/compress-inl.h compression/io_win.cc @@ -48,7 +49,6 @@ set(SOURCES compression/sfp-inl.h compression/shared.h compression/test_util-inl.h - compression/weights_raw.h backprop/activations.h backprop/backward.cc backprop/backward.h diff --git a/backprop/activations.h b/backprop/activations.h index aee0341..4f2e821 100644 --- a/backprop/activations.h +++ b/backprop/activations.h @@ -20,32 +20,51 @@ #include +#include "compression/compress.h" // MatStorageT #include "util/allocator.h" // ByteStorageT namespace gcpp { template struct ForwardLayer { + ForwardLayer() + : input("input", kSeqLen, kModelDim), + pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim), + qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim), + att("att", kSeqLen * kHeads, kSeqLen), + att_out("att_out", kSeqLen * kHeads, kQKVDim), + att_post1("att_post1", kSeqLen, kModelDim), + attention_out("attention_out", kSeqLen, kModelDim), + bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim), + ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2), + ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {} + static constexpr size_t kSeqLen = TConfig::kSeqLen; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - std::array input; - std::array pre_att_rms_out; - std::array qkv; - std::array att; - std::array att_out; - std::array att_post1; - std::array attention_out; - std::array bf_pre_ffw_rms_out; - std::array ffw_hidden; - std::array ffw_hidden_gated; + + MatStorageT input; + MatStorageT pre_att_rms_out; + MatStorageT qkv; + MatStorageT att; + MatStorageT att_out; + MatStorageT att_post1; + MatStorageT attention_out; + MatStorageT bf_pre_ffw_rms_out; + MatStorageT ffw_hidden; + MatStorageT ffw_hidden_gated; }; template struct ForwardPass { - ForwardPass() {} // prevents placement-new calling memset + ForwardPass() + : final_layer_output("final_layer_output", kSeqLen, kModelDim), + final_norm_output("final_norm_output", kSeqLen, kModelDim), + logits("logits", kSeqLen, kVocabSize), + probs("probs", kSeqLen, kVocabSize) { + } // prevents placement-new calling memset static constexpr size_t kSeqLen = TConfig::kSeqLen; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -53,16 +72,20 @@ struct ForwardPass { static constexpr size_t kLayers = TConfig::kLayers; std::array, kLayers> layers; - std::array final_layer_output; - std::array final_norm_output; - std::array logits; - std::array probs; + MatStorageT final_layer_output; + MatStorageT final_norm_output; + MatStorageT logits; + MatStorageT probs; }; template struct AllocateForwardPass { ByteStorageT operator()() const { - return AllocateSizeof>(); + ByteStorageT c_weights_u8 = AllocateSizeof>(); + auto* c_weights = + reinterpret_cast*>(c_weights_u8.get()); + new (c_weights) ForwardPass(); + return c_weights_u8; } }; @@ -74,7 +97,7 @@ class ActivationsWrapper { public: ActivationsWrapper() : data_(AllocateSizeof()), - activations_(*reinterpret_cast(data_.get())) {} + activations_(*(new(data_.get()) WrappedT())) {} const WrappedT& get() const { return activations_; } WrappedT& get() { return activations_; } diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index df7c6fb..f765a5a 100644 --- 
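Editorial aside (not part of the patch): the AllocateForwardPass and ActivationsWrapper changes above share one pattern -- allocate a raw byte buffer, then placement-new the object into it so that the new MatStorageT members run their constructors instead of being left as zeroed bytes. A minimal sketch of that pattern, assuming ByteStorageT is an aligned uint8_t buffer as in util/allocator.h; the helper name is hypothetical.

#include <cstdint>
#include <new>
#include "hwy/aligned_allocator.h"

// Allocates sizeof(T) aligned bytes and constructs a T in place. The buffer is
// not zeroed first; the object's own constructors initialize its members.
template <typename T>
hwy::AlignedFreeUniquePtr<uint8_t[]> AllocateAndConstruct() {
  hwy::AlignedFreeUniquePtr<uint8_t[]> bytes =
      hwy::AllocateAligned<uint8_t>(sizeof(T));
  new (bytes.get()) T();  // placement-new, as in AllocateForwardPass above
  return bytes;
}

A real owner would also need to invoke ~T() before releasing the buffer; the wrappers in this patch keep the object alive for the buffer's lifetime.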
a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -168,11 +168,11 @@ static HWY_NOINLINE void InputEmbeddingVJP( } } -template typename LayerT> -void LayerVJP(const LayerT& weights, +template +void LayerVJP(const LayerT& weights, const ForwardLayer& forward, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, - LayerT& grad, ForwardLayer& backward, + LayerT& grad, ForwardLayer& backward, const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { static constexpr size_t kModelDim = TConfig::kModelDim; @@ -226,8 +226,7 @@ void LayerVJP(const LayerT& weights, backward.attention_out.data() + pos * kModelDim, kModelDim); } - hwy::ZeroBytes(backward.qkv.data(), - num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0])); + backward.qkv.ZeroInit(); MultiHeadMatMulVJP( weights.attn_vec_einsum_w.data(), forward.att_out.data(), @@ -343,12 +342,10 @@ static HWY_NOINLINE void CrossEntropyLossGrad( } } -template typename WeightsT, - template typename LayerT> -void CrossEntropyLossBackwardPass(const Prompt& prompt, - const WeightsT& weights, +template +void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights, const ForwardPass& forward, - WeightsT& grad, + WeightsT& grad, ForwardPass& backward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { diff --git a/backprop/backward.cc b/backprop/backward.cc index da63e3a..c186952 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -52,7 +52,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, using TAct = ForwardPass; const auto& forward = *reinterpret_cast(forward_u8.get()); auto& backward = *reinterpret_cast(backward_u8.get()); - CrossEntropyLossBackwardPass( + CrossEntropyLossBackwardPass, + CompressedLayer>( prompt, weights, forward, grad, backward, inv_timescale, pool); } diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index 697d386..a804cd3 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -25,8 +25,8 @@ #include "backprop/activations.h" #include "backprop/common_scalar.h" #include "backprop/prompt.h" -#include "compression/weights_raw.h" #include "gemma/common.h" // EmbeddingScaling +#include "gemma/weights.h" namespace gcpp { template @@ -199,13 +199,11 @@ void InputEmbeddingVJPT(const T* w, const std::vector& tokens, T scaling, } } -template -void LayerVJP(const Layer& weights, - const ForwardLayer& forward, - const T* dy, - Layer& grad, - ForwardLayer& backward, - size_t num_tokens) { +template +void LayerVJP(const CompressedLayer& weights, + const ForwardLayer& forward, const T* dy, + CompressedLayer& grad, + ForwardLayer& backward, size_t num_tokens) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kSeqLen = TConfig::kSeqLen; static constexpr size_t kQKVDim = TConfig::kQKVDim; @@ -298,11 +296,11 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { } } -template +template void CrossEntropyLossBackwardPass(const Prompt& prompt, - const Weights& weights, + const CompressedWeights& weights, const ForwardPass& forward, - Weights& grad, + CompressedWeights& grad, ForwardPass& backward) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kVocabSize = TConfig::kVocabSize; diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index 15aa876..262a121 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -17,7 +17,7 @@ #include #include -#include // memset +#include // memcpy #include 
#include @@ -32,8 +32,9 @@ #include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "compression/weights_raw.h" +#include "compression/compress.h" #include "gemma/configs.h" +#include "gemma/weights.h" namespace gcpp { @@ -44,14 +45,14 @@ TEST(BackPropTest, MatMulVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array weights; - std::array x; - std::array grad; - std::array dx; - std::array c_weights; - std::array c_x; - std::array c_y; - std::array dy; + MatStorageT weights("weights", kRows, kCols); + MatStorageT x("x", kTokens, kCols); + MatStorageT grad("grad", kRows, kCols); + MatStorageT dx("dx", kTokens, kCols); + MatStorageT c_weights("c_weights", kRows, kCols); + MatStorageT c_x("c_x", kTokens, kCols); + MatStorageT c_y("c_y", kTokens, kRows); + MatStorageT dy("dy", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -63,7 +64,7 @@ TEST(BackPropTest, MatMulVJP) { MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - memset(&grad, 0, sizeof(grad)); + grad.ZeroInit(); MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), kRows, kCols, kTokens); TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__); @@ -79,14 +80,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array weights; - std::array x; - std::array grad; - std::array dx; - std::array c_weights; - std::array c_x; - std::array c_y; - std::array dy; + MatStorageT weights("weights", kRows, kCols * kHeads); + MatStorageT x("x", kTokens, kCols * kHeads); + MatStorageT grad("grad", kRows, kCols * kHeads); + MatStorageT dx("dx", kTokens, kCols * kHeads); + MatStorageT c_weights("c_weights", kRows, kCols * kHeads); + MatStorageT c_x("c_x", kTokens, kCols * kHeads); + MatStorageT c_y("c_y", kTokens, kRows); + MatStorageT dy("dy", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -99,7 +100,7 @@ TEST(BackPropTest, MultiHeadMatMulVJP) { kCols, kTokens); return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - memset(&grad, 0, sizeof(grad)); + grad.ZeroInit(); MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), kHeads, kRows, kCols, kTokens); TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__); @@ -113,14 +114,14 @@ TEST(BackPropTest, RMSNormVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array weights; - std::array grad; - std::array x; - std::array dx; - std::array dy; - std::array c_weights; - std::array c_x; - std::array c_y; + MatStorageT weights("weights", N, 1); + MatStorageT grad("grad", N, 1); + MatStorageT x("x", K, N); + MatStorageT dx("dx", K, N); + MatStorageT dy("dy", K, N); + MatStorageT c_weights("c_weights", N, 1); + MatStorageT c_x("c_x", K, N); + MatStorageT c_y("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -132,7 +133,7 @@ TEST(BackPropTest, RMSNormVJP) { RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); return DotT(dy.data(), c_y.data(), K * N); }; - memset(&grad, 0, sizeof(grad)); + grad.ZeroInit(); RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), N, K); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); @@ -145,23 +146,23 @@ TEST(BackPropTest, SoftmaxVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - 
std::array dx; - std::array dy; - std::array c_x; - std::array c_y; + MatStorageT x("x", N, 1); + MatStorageT dx("dx", N, 1); + MatStorageT dy("dy", N, 1); + MatStorageT c_x("c_x", N, 1); + MatStorageT c_y("c_y", N, 1); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - memcpy(c_y.data(), c_x.data(), sizeof(c_x)); + memcpy(c_y.data(), c_x.data(), c_x.SizeBytes()); Softmax(c_y.data(), N); return DotT(dy.data(), c_y.data(), N); }; Softmax(x.data(), N); - memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); + memcpy(dx.data(), dy.data(), dx.SizeBytes()); SoftmaxVJPT(x.data(), dx.data(), N); TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); } @@ -171,15 +172,16 @@ TEST(BackPropTest, MaskedSoftmaxVJP) { static const size_t kSeqLen = 16; static const size_t kHeads = 2; static const size_t kTokens = 14; - static const size_t N = kHeads * kSeqLen * kSeqLen; + static const size_t N = kTokens * kHeads * kSeqLen; std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - std::array dy; - std::array dx = {}; - std::array c_x; - std::array c_y; + MatStorageT x("x", N, 1); + MatStorageT dy("dy", N, 1); + MatStorageT dx("dx", N, 1); + MatStorageT c_x("c_x", N, 1); + MatStorageT c_y("c_y", N, 1); + dx.ZeroInit(); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); @@ -187,12 +189,12 @@ TEST(BackPropTest, MaskedSoftmaxVJP) { RandInit(dy, 1.0, gen); auto func = [&]() { memcpy(c_y.data(), c_x.data(), - kTokens * kHeads * kSeqLen * sizeof(c_x[0])); + kTokens * kHeads * kSeqLen * sizeof(c_x.At(0))); MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen); return DotT(dy.data(), c_y.data(), N); }; MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen); - memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0])); + memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0))); MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen); TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); } @@ -203,11 +205,11 @@ TEST(BackPropTest, SoftcapVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - std::array dx; - std::array dy; - std::array c_x; - std::array c_y; + MatStorageT x("x", N, 1); + MatStorageT dx("dx", N, 1); + MatStorageT dy("dy", N, 1); + MatStorageT c_x("c_x", N, 1); + MatStorageT c_y("c_y", N, 1); constexpr float kCap = 30.0f; for (int iter = 0; iter < 10; ++iter) { @@ -215,12 +217,12 @@ TEST(BackPropTest, SoftcapVJP) { Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0])); + memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0))); Softcap(kCap, c_y.data(), N); return DotT(dy.data(), c_y.data(), N); }; Softcap(kCap, x.data(), N); - memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); + memcpy(dx.data(), dy.data(), dx.SizeBytes()); SoftcapVJPT(kCap, x.data(), dx.data(), N); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); } @@ -232,9 +234,9 @@ TEST(BackPropTest, CrossEntropyLossGrad) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - std::array dx; - std::array c_x; + MatStorageT x("x", K, V); + MatStorageT dx("dx", K, V); + MatStorageT c_x("c_x", K, V); Prompt prompt; prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; @@ -259,11 +261,11 @@ TEST(BackPropTest, GatedGeluVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - std::array dx; - std::array dy; - std::array c_x; - std::array 
c_y; + MatStorageT x("x", K, 2 * N); + MatStorageT dx("dx", K, 2 * N); + MatStorageT dy("dy", K, N); + MatStorageT c_x("c_x", K, 2 * N); + MatStorageT c_y("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0, gen); @@ -284,15 +286,17 @@ TEST(BackPropTest, MaskedAttentionVJP) { static const size_t kQKVDim = 8; static const size_t kTokens = 14; static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim; - static const size_t kOutSize = kSeqLen * kHeads * kSeqLen; + static const size_t kOutSize = kTokens * kHeads * kSeqLen; std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array x; - std::array dx = {}; - std::array dy; - std::array c_x; - std::array c_y; + MatStorageT x("x", kQKVSize, 1); + MatStorageT dx("dx", kQKVSize, 1); + MatStorageT dy("dy", kOutSize, 1); + MatStorageT c_x("c_x", kQKVSize, 1); + MatStorageT c_y("c_y", kOutSize, 1); + dx.ZeroInit(); + c_y.ZeroInit(); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0, gen); @@ -320,14 +324,17 @@ TEST(BackPropTest, MixByAttentionVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array qkv; - std::array dqkv = {}; - std::array attn; - std::array dattn = {}; - std::array dy; - std::array c_qkv; - std::array c_attn; - std::array c_y; + MatStorageT qkv("qkv", kQKVSize, 1); + MatStorageT dqkv("dqkv", kQKVSize, 1); + MatStorageT attn("attn", kAttnSize, 1); + MatStorageT dattn("dattn", kAttnSize, 1); + MatStorageT dy("dy", kOutSize, 1); + MatStorageT c_qkv("c_qkv", kQKVSize, 1); + MatStorageT c_attn("c_attn", kAttnSize, 1); + MatStorageT c_y("c_y", kOutSize, 1); + dqkv.ZeroInit(); + dattn.ZeroInit(); + c_y.ZeroInit(); for (int iter = 0; iter < 10; ++iter) { RandInit(qkv, 1.0, gen); @@ -354,11 +361,11 @@ TEST(BackPropTest, InputEmbeddingVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - std::array weights; - std::array grad; - std::array dy; - std::array c_weights; - std::array c_y; + MatStorageT weights("weights", kVocabSize, kModelDim); + MatStorageT grad("grad", kVocabSize, kModelDim); + MatStorageT dy("dy", kSeqLen, kModelDim); + MatStorageT c_weights("c_weights", kVocabSize, kModelDim); + MatStorageT c_y("c_y", kSeqLen, kModelDim); std::vector tokens = { 0, 1, 2, 3, 0, 1, 2 }; size_t num_tokens = tokens.size() - 1; @@ -370,14 +377,16 @@ TEST(BackPropTest, InputEmbeddingVJP) { InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim); return DotT(dy.data(), c_y.data(), num_tokens * kModelDim); }; - memset(&grad, 0, sizeof(grad)); + grad.ZeroInit(); InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(), kModelDim); TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__); } } +template struct TestConfig : ConfigBaseGemmaV2 { + using Weight = T; static constexpr int kSeqLen = 18; static constexpr int kVocabSize = 12; static constexpr int kModelDim = 32; @@ -399,17 +408,21 @@ TEST(BackPropTest, LayerVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim; - Layer weights; - Layer grad; - ForwardLayer forward; - ForwardLayer backward = {}; - Layer c_weights; - ForwardLayer c_forward; + const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim; + CompressedLayer> weights; + CompressedLayer> grad; + ForwardLayer> forward; + ForwardLayer> backward = {}; + CompressedLayer> c_weights; + ForwardLayer> c_forward; std::array y; - std::array dy; + MatStorageT dy("dy", kOutputSize, 1); std::array c_y; const size_t 
num_tokens = 3; + weights.Allocate(); + grad.Allocate(); + c_weights.Allocate(); + backward.input.ZeroInit(); for (size_t iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0, gen); @@ -419,9 +432,9 @@ TEST(BackPropTest, LayerVJP) { Complexify(forward.input, c_forward.input); auto func = [&]() { ApplyLayer(c_weights, c_forward, num_tokens, c_y.data()); - return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim); + return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim); }; - memset(&grad, 0, sizeof(grad)); + grad.ZeroInit(/*layer_idx=*/0); ApplyLayer(weights, forward, num_tokens, y.data()); LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens); TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11, @@ -434,12 +447,12 @@ TEST(BackPropTest, EndToEnd) { std::mt19937 gen(42); using T = double; using TC = std::complex; - WeightsWrapper weights; - WeightsWrapper grad; - ForwardPass forward; - ForwardPass backward; - WeightsWrapper c_weights; - ForwardPass c_forward; + WeightsWrapper> weights; + WeightsWrapper> grad; + ForwardPass> forward; + ForwardPass> backward; + WeightsWrapper> c_weights; + ForwardPass> c_forward; ReverseSequenceSampler training_task({0, 0, 1, 1}); std::vector batch = training_task.SampleBatch(3, gen); @@ -448,7 +461,7 @@ TEST(BackPropTest, EndToEnd) { ReverseSequenceSampler::LogPrompt(prompt); RandInit(weights.get(), 1.0, gen); CrossEntropyLossForwardPass(prompt, weights.get(), forward); - grad.clear(); + grad.ZeroInit(); CrossEntropyLossBackwardPass( prompt, weights.get(), forward, grad.get(), backward); @@ -461,9 +474,9 @@ TEST(BackPropTest, EndToEnd) { } } -template -void MulByConstAndAddT(T c, const Layer& x, - Layer& out) { +template +void MulByConstAndAddT(T c, const CompressedLayer& x, + CompressedLayer& out) { MulByConstAndAddT(c, x.pre_attention_norm_scale, out.pre_attention_norm_scale); MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w); @@ -473,9 +486,9 @@ void MulByConstAndAddT(T c, const Layer& x, MulByConstAndAddT(c, x.linear_w, out.linear_w); } -template -void MulByConstAndAddT(T c, const Weights& x, - Weights& out) { +template +void MulByConstAndAddT(T c, const CompressedWeights& x, + CompressedWeights& out) { static constexpr size_t kLayers = TConfig::kLayers; MulByConstAndAddT(c, x.embedder_input_embedding, out.embedder_input_embedding); @@ -486,9 +499,9 @@ void MulByConstAndAddT(T c, const Weights& x, } // Evaluates forward pass on a batch. -template +template T CrossEntropyLossForwardPass(const std::vector& batch, - const WeightsWrapper& weights, + const WeightsWrapper& weights, ForwardPass& forward) { T loss = 0.0; for (const Prompt& prompt : batch) { @@ -501,14 +514,13 @@ T CrossEntropyLossForwardPass(const std::vector& batch, // Evaluates forward pass on a batch by applying gradient with the given // learning rate. Does not update weights, but uses the given tmp weights // instead. 
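Editorial aside: the MulByConstAndAddT overloads above amount to an axpy-style update (out += c * x) over every tensor, and the learning-rate evaluation described in the comment above (and FindOptimalUpdate below) uses it to score a step without touching the real weights. A scalar-vector sketch with hypothetical names:

#include <cstddef>
#include <functional>
#include <vector>

// out += c * x, the per-tensor operation MulByConstAndAddT applies.
void MulByConstAndAdd(float c, const std::vector<float>& x,
                      std::vector<float>& out) {
  for (size_t i = 0; i < out.size(); ++i) out[i] += c * x[i];
}

// Evaluates the loss at a trial learning rate: copy weights into tmp, add the
// scaled negative gradient, and let the caller score tmp. The real weights
// stay untouched, so the same step can be retried with a different rate.
float LossAtLearningRate(
    float lr, const std::vector<float>& weights,
    const std::vector<float>& grad, size_t batch_size, std::vector<float>& tmp,
    const std::function<float(const std::vector<float>&)>& loss) {
  tmp = weights;
  MulByConstAndAdd(-lr / static_cast<float>(batch_size), grad, tmp);
  return loss(tmp);
}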
-template -T CrossEntropyLossForwardPass(T learning_rate, - const std::vector& batch, - const WeightsWrapper& weights, - const WeightsWrapper& grad, - WeightsWrapper& tmp, +template +T CrossEntropyLossForwardPass(T learning_rate, const std::vector& batch, + const WeightsWrapper& weights, + const WeightsWrapper& grad, + WeightsWrapper& tmp, ForwardPass& forward) { - tmp.copy(weights); + tmp.CopyFrom(weights); const T scale = -learning_rate / batch.size(); MulByConstAndAddT(scale, grad.get(), tmp.get()); return CrossEntropyLossForwardPass(batch, tmp, forward); @@ -517,13 +529,13 @@ T CrossEntropyLossForwardPass(T learning_rate, // Uses line search in the negative gradient direction to update weights. We do // this so that we can test that each step during the gradient descent can // decrease the objective function value. -template -T FindOptimalUpdate(const WeightsWrapper& grad, - WeightsWrapper& weights, - WeightsWrapper& tmp, +template +T FindOptimalUpdate(const WeightsWrapper& grad, + WeightsWrapper& weights, + WeightsWrapper& tmp, ForwardPass& forward, - const std::vector& batch, - T loss, T initial_learning_rate) { + const std::vector& batch, T loss, + T initial_learning_rate) { T lr0 = initial_learning_rate; T loss0 = CrossEntropyLossForwardPass( lr0, batch, weights, grad, tmp, forward); @@ -556,13 +568,13 @@ TEST(BackProptest, Convergence) { std::mt19937 gen(42); using T = float; using TC = std::complex; - WeightsWrapper weights; - WeightsWrapper grad; - WeightsWrapper tmp; - ForwardPass forward; - ForwardPass backward; - WeightsWrapper c_weights; - ForwardPass c_forward; + WeightsWrapper> weights; + WeightsWrapper> grad; + WeightsWrapper> tmp; + ForwardPass> forward; + ForwardPass> backward; + WeightsWrapper> c_weights; + ForwardPass> c_forward; constexpr size_t kBatchSize = 5; ReverseSequenceSampler training_task({0, 0, 0, 1, 1}); T learning_rate = 0.01; @@ -579,7 +591,7 @@ TEST(BackProptest, Convergence) { size_t step = 0; while (!stop) { T loss = 0.0; - grad.clear(); + grad.ZeroInit(); std::mt19937 sgen(42); std::vector batch = training_task.SampleBatch(kBatchSize, sgen); for (const Prompt& prompt : batch) { diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index decec20..01c5e73 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -32,9 +32,9 @@ #include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "compression/weights_raw.h" #include "gemma/activations.h" #include "gemma/configs.h" +#include "gemma/weights.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -48,6 +48,7 @@ // After highway.h #include "backprop/backward-inl.h" #include "backprop/forward-inl.h" +#include "compression/compress.h" #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); @@ -60,17 +61,17 @@ void TestMatMulVJP() { static const size_t kTokens = 5; hwy::ThreadPool pool(8); std::mt19937 gen(42); - HWY_ALIGN std::array weights; - HWY_ALIGN std::array x; - HWY_ALIGN std::array dy; - HWY_ALIGN std::array grad; - HWY_ALIGN std::array dx; - HWY_ALIGN std::array grad_scalar; - HWY_ALIGN std::array dx_scalar; + MatStorageT weights("weights", kRows, kCols); + MatStorageT x("x", kTokens, kCols); + MatStorageT dy("dy", kTokens, kRows); + MatStorageT grad("grad", kRows, kCols); + MatStorageT dx("dx", kTokens, kCols); + MatStorageT grad_scalar("grad_scalar", kRows, kCols); + MatStorageT dx_scalar("dx_scalar", kTokens, kCols); using TC = std::complex; - std::array c_weights; - std::array c_x; - std::array c_y; + 
MatStorageT c_weights("c_weights", kRows, kCols); + MatStorageT c_x("c_x", kTokens, kCols); + MatStorageT c_y("c_y", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0f * (1 << iter), gen); @@ -83,13 +84,13 @@ void TestMatMulVJP() { return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - hwy::ZeroBytes(&grad, sizeof(grad)); + grad.ZeroInit(); MatMulVJP(weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(), pool); - TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); - TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); + TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); + grad_scalar.ZeroInit(); MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kRows, kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__); @@ -104,17 +105,17 @@ void TestMultiHeadMatMulVJP() { static const size_t kTokens = 3; hwy::ThreadPool pool(8); std::mt19937 gen(42); - HWY_ALIGN std::array weights; - HWY_ALIGN std::array x; - HWY_ALIGN std::array grad; - HWY_ALIGN std::array dx; - HWY_ALIGN std::array dy; - HWY_ALIGN std::array grad_scalar; - HWY_ALIGN std::array dx_scalar; + MatStorageT weights("weights", kRows, kCols * kHeads); + MatStorageT x("x", kTokens, kCols * kHeads); + MatStorageT grad("grad", kRows, kCols * kHeads); + MatStorageT dx("dx", kTokens, kCols * kHeads); + MatStorageT dy("dy", kTokens, kRows); + MatStorageT grad_scalar("grad_scalar", kRows, kCols * kHeads); + MatStorageT dx_scalar("dx_scalar", kTokens, kCols * kHeads); using TC = std::complex; - std::array c_weights; - std::array c_x; - std::array c_y; + MatStorageT c_weights("c_weights", kRows, kCols * kHeads); + MatStorageT c_x("c_x", kTokens, kCols * kHeads); + MatStorageT c_y("c_y", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0f * (1 << iter), gen); @@ -128,14 +129,14 @@ void TestMultiHeadMatMulVJP() { return DotT(dy.data(), c_y.data(), kTokens * kRows); }; - hwy::ZeroBytes(&grad, sizeof(grad)); + grad.ZeroInit(); MultiHeadMatMulVJP( weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(), pool); - TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); - TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); + TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); + grad_scalar.ZeroInit(); MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kHeads, kRows, kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); @@ -148,17 +149,17 @@ void TestRMSNormVJP() { static const size_t N = 64; hwy::ThreadPool pool(8); std::mt19937 gen(42); - HWY_ALIGN std::array weights; - HWY_ALIGN std::array x; - HWY_ALIGN std::array grad; - HWY_ALIGN std::array dx; - HWY_ALIGN std::array dy; - HWY_ALIGN std::array grad_scalar; - HWY_ALIGN std::array dx_scalar; + MatStorageT weights("weights", N, 1); + MatStorageT x("x", K, N); + MatStorageT grad("grad", N, 1); + MatStorageT dx("dx", K, N); + MatStorageT dy("dy", K, N); + MatStorageT grad_scalar("grad_scalar", N, 1); + MatStorageT dx_scalar("dx_scalar", K, N); using TC = std::complex; - std::array c_weights; - std::array c_x; - std::array c_y; + MatStorageT c_weights("c_weights", N, 1); + MatStorageT c_x("c_x", K, N); + MatStorageT c_y("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { 
RandInit(weights, 1.0f * (1 << iter), gen); @@ -171,13 +172,13 @@ void TestRMSNormVJP() { return DotT(dy.data(), c_y.data(), K * N); }; - hwy::ZeroBytes(&grad, sizeof(grad)); + grad.ZeroInit(); RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), dx.data(), pool); - TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); - TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); + TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar)); + grad_scalar.ZeroInit(); RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), N, K); TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); @@ -185,7 +186,9 @@ void TestRMSNormVJP() { } } -struct TestConfig : public ConfigBaseGemmaV2 { +template +struct TestConfig : ConfigBaseGemmaV2 { + using Weight = T; static constexpr int kSeqLen = 24; static constexpr int kVocabSize = 16; static constexpr int kModelDim = 32; @@ -206,20 +209,22 @@ struct TestConfig : public ConfigBaseGemmaV2 { void TestEndToEnd() { std::mt19937 gen(42); hwy::ThreadPool pool(0); - WeightsWrapper weights; - WeightsWrapper grad; - ActivationsWrapper forward0; - ActivationsWrapper forward1; - ActivationsWrapper backward; + using WeightsF = CompressedWeights>; + using LayerF = CompressedLayer>; + WeightsWrapper> weights; + WeightsWrapper> grad; + ActivationsWrapper> forward0; + ActivationsWrapper> forward1; + ActivationsWrapper> backward; using TC = std::complex; - WeightsWrapper c_weights; - ForwardPass c_forward; + WeightsWrapper> c_weights; + ForwardPass> c_forward; ReverseSequenceSampler training_task({0, 0, 1, 1}); std::vector batch = training_task.SampleBatch(3, gen); RowVectorBatch inv_timescale = - Activations::CreateInvTimescale(); + Activations::CreateInvTimescale>(); for (const Prompt& prompt : batch) { ReverseSequenceSampler::LogPrompt(prompt); RandInit(weights.get(), 1.0f, gen); @@ -227,14 +232,15 @@ void TestEndToEnd() { float loss0 = CrossEntropyLossForwardPass( prompt, weights.get(), forward0.get()); - float loss1 = CrossEntropyLossForwardPass( - prompt.tokens, prompt.context_size, weights.get(), forward1.get(), - inv_timescale, pool); + float loss1 = + CrossEntropyLossForwardPass, WeightsF, LayerF>( + prompt.tokens, prompt.context_size, weights.get(), forward1.get(), + inv_timescale, pool); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); - grad.clear(); - CrossEntropyLossBackwardPass( + grad.ZeroInit(); + CrossEntropyLossBackwardPass, WeightsF, LayerF>( prompt, weights.get(), forward1.get(), grad.get(), backward.get(), inv_timescale, pool); diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h index 034520d..c61086d 100644 --- a/backprop/common_scalar.h +++ b/backprop/common_scalar.h @@ -18,9 +18,10 @@ #include -#include #include +#include "compression/compress.h" // MatStorageT + namespace gcpp { template @@ -57,9 +58,9 @@ void MulByConstAndAddT(T c, const T* x, T* out, size_t N) { } } -template -void MulByConstAndAddT(T c, const std::array& x, std::array& out) { - MulByConstAndAddT(c, x.data(), out.data(), N); +template +void MulByConstAndAddT(T c, const MatPtrT& x, MatPtrT& out) { + MulByConstAndAddT(c, x.data(), out.data(), x.NumElements()); } template diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index c799cf4..b6b1dc0 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -93,8 +93,8 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, return loss * scaling; 
} -template typename LayerT> -void ApplyForwardLayer(const LayerT& weights, +template +void ApplyForwardLayer(const LayerT& weights, ForwardLayer& activations, size_t num_tokens, float* HWY_RESTRICT output, const RowVectorBatch& inv_timescale, @@ -171,8 +171,7 @@ void ApplyForwardLayer(const LayerT& weights, } }); - hwy::ZeroBytes(activations.attention_out.data(), - num_tokens * kModelDim * sizeof(activations.attention_out[0])); + activations.attention_out.ZeroInit(); for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t head = 0; head < kHeads; ++head) { MatVec( @@ -227,11 +226,9 @@ void ApplyForwardLayer(const LayerT& weights, } } -template typename WeightsT, - template typename LayerT> +template float CrossEntropyLossForwardPass(const std::vector& prompt, - size_t context_size, - const WeightsT& weights, + size_t context_size, const WeightsT& weights, ForwardPass& forward, const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { @@ -281,7 +278,7 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, } hwy::CopyBytes(forward.logits.data(), forward.probs.data(), - num_tokens * kVocabSize * sizeof(forward.logits[0])); + num_tokens * kVocabSize * sizeof(forward.logits.At(0))); for (size_t pos = 0; pos < num_tokens; ++pos) { Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); diff --git a/backprop/forward.cc b/backprop/forward.cc index 29721d2..5b2cf1a 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -46,8 +46,8 @@ float CrossEntropyLossForwardPass(const Prompt& prompt, *reinterpret_cast*>(weights_u8.get()); auto& forward = *reinterpret_cast*>(forward_u8.get()); - return CrossEntropyLossForwardPass( + return CrossEntropyLossForwardPass, + CompressedLayer>( prompt.tokens, prompt.context_size, weights, forward, inv_timescale, pool); } diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 5e33d1d..064112b 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -26,8 +26,9 @@ #include "backprop/activations.h" #include "backprop/common_scalar.h" #include "backprop/prompt.h" -#include "compression/weights_raw.h" #include "gemma/common.h" // EmbeddingScaling +#include "gemma/weights.h" +#include "hwy/base.h" namespace gcpp { @@ -116,6 +117,8 @@ void GatedGelu(const T* in, T* out, size_t N, size_t K) { template void InputEmbedding(const T* w, const std::vector& tokens, T scaling, T* y, size_t N) { + HWY_ASSERT(w != nullptr); + HWY_ASSERT(y != nullptr); const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; for (size_t i = 0; i < num_tokens; ++i) { int token = tokens[i]; @@ -166,10 +169,10 @@ void MixByAttention(const T* qkv, const T* attention, T* output, } } } -template -void ApplyLayer(const Layer& weights, - ForwardLayer& activations, - size_t num_tokens, T* output) { +template +void ApplyLayer(const CompressedLayer& weights, + ForwardLayer& activations, size_t num_tokens, + T* output) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kSeqLen = TConfig::kSeqLen; static constexpr size_t kQKVDim = TConfig::kQKVDim; @@ -244,9 +247,9 @@ T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { return loss * scaling; } -template +template T CrossEntropyLossForwardPass(const Prompt& prompt, - const Weights& weights, + const CompressedWeights& weights, ForwardPass& forward) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kVocabSize = TConfig::kVocabSize; @@ -282,7 +285,7 @@ T CrossEntropyLossForwardPass(const Prompt& prompt, } memcpy(forward.probs.data(), forward.logits.data(), - num_tokens * kVocabSize * sizeof(forward.logits[0])); + num_tokens * kVocabSize * sizeof(forward.logits.At(0))); Softmax(forward.probs.data(), kVocabSize, num_tokens); return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index ed51183..800f2fa 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -21,6 +21,8 @@ #include "compression/compress.h" #include "gemma/common.h" #include "gemma/weights.h" +#include "util/allocator.h" +#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -32,14 +34,14 @@ class WeightInitializer { public: WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} - template - void operator()(const char* name, CompressedArray& tensor) { - float* data = tensor.data(); - for (size_t i = 0; i < N; ++i) { + void operator()(const char* name, hwy::Span tensors) { + float* data = tensors[0]->data(); + for (size_t i = 0; i < tensors[0]->NumElements(); ++i) { data[i] = dist_(gen_); } - tensor.set_scale(1.0f); + tensors[0]->set_scale(1.0f); } + private: std::normal_distribution dist_; std::mt19937& gen_; @@ -54,7 +56,8 @@ struct RandInitWeightsT { // TODO(szabadka) Use the same weight initialization method as in the python // version. 
WeightInitializer init(gen); - ForEachTensor1(init, weights); + CompressedWeights::ForEachTensor({&weights}, + ForEachType::kLoadNoToc, init); } }; @@ -66,17 +69,13 @@ class AdamUpdater { cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))), norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {} - template - void operator()(const char* name, - const CompressedArray& grad, - CompressedArray& weights, - CompressedArray& grad_m, - CompressedArray& grad_v) { - const float* HWY_RESTRICT g = grad.data(); - float* HWY_RESTRICT w = weights.data(); - float* HWY_RESTRICT m = grad_m.data(); - float* HWY_RESTRICT v = grad_v.data(); - for (size_t i = 0; i < kCapacity; ++i) { + void operator()(const char* name, const MatPtr& grad, MatPtr& weights, + MatPtr& grad_m, MatPtr& grad_v) { + const float* HWY_RESTRICT g = grad.data(); + float* HWY_RESTRICT w = weights.data(); + float* HWY_RESTRICT m = grad_m.data(); + float* HWY_RESTRICT v = grad_v.data(); + for (size_t i = 0; i < grad.NumElements(); ++i) { m[i] *= beta1_; m[i] += cbeta1_ * g[i]; v[i] *= beta2_; @@ -105,12 +104,16 @@ struct AdamUpdateT { const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { using TWeights = CompressedWeights; - const auto& grad = *reinterpret_cast(grad_u8.get()); + auto& grad = *reinterpret_cast(grad_u8.get()); auto& weights = *reinterpret_cast(weights_u8.get()); auto& grad_m = *reinterpret_cast(grad_m_u8.get()); auto& grad_v = *reinterpret_cast(grad_v_u8.get()); AdamUpdater updater(alpha, beta1, beta2, epsilon, t); - ForEachTensor4(updater, grad, weights, grad_m, grad_v); + TWeights::ForEachTensor( + {&grad, &weights, &grad_m, &grad_v}, ForEachType::kLoadNoToc, + [&updater](const char* name, hwy::Span tensors) { + updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); + }); } }; diff --git a/backprop/test_util.h b/backprop/test_util.h index ef257e7..bfa2cc5 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -18,27 +18,57 @@ #include -#include #include #include +#include #include "gtest/gtest.h" -#include "compression/weights_raw.h" +#include "compression/compress.h" +#include "gemma/weights.h" +#include "util/allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -template -void Complexify(const std::array& x, - std::array, kLen>& c_x) { - for (size_t i = 0; i < kLen; ++i) { - c_x[i] = std::complex(x[i], 0.0); +template +void RandInit(MatPtrT& x, T stddev, std::mt19937& gen) { + std::normal_distribution dist(0.0, stddev); + for (size_t i = 0; i < x.NumElements(); ++i) { + x.At(i) = dist(gen); } } +// TODO: make a member of Layer. 
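Editorial aside: per element, the AdamUpdater in backprop/optimizer.cc above applies the usual bias-corrected Adam step (norm1_/norm2_ are 1/(1-beta^t)); the tail of its loop is cut off in this hunk, so the sketch below assumes the standard form.

#include <cmath>
#include <cstddef>

// w -= alpha * (m * norm1) / (sqrt(v * norm2) + epsilon), where m and v are
// exponential moving averages of the gradient and its square.
void AdamStep(float* w, float* m, float* v, const float* g, size_t n,
              float alpha, float beta1, float beta2, float epsilon,
              float norm1, float norm2) {
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
    w[i] -= alpha * (m[i] * norm1) / (std::sqrt(v[i] * norm2) + epsilon);
  }
}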
+template +void RandInit(CompressedLayer& w, T stddev, std::mt19937& gen) { + RandInit(w.pre_attention_norm_scale, stddev, gen); + RandInit(w.attn_vec_einsum_w, stddev, gen); + RandInit(w.qkv_einsum_w, stddev, gen); + RandInit(w.pre_ffw_norm_scale, stddev, gen); + RandInit(w.gating_einsum_w, stddev, gen); + RandInit(w.linear_w, stddev, gen); +} -template -void Complexify(const Layer& w, - Layer, TConfig>& c_w) { +template +void RandInit(CompressedWeights& w, T stddev, std::mt19937& gen) { + static constexpr size_t kLayers = TConfig::kLayers; + RandInit(w.embedder_input_embedding, stddev, gen); + RandInit(w.final_norm_scale, stddev, gen); + for (size_t i = 0; i < kLayers; ++i) { + RandInit(*w.GetLayer(i), stddev, gen); + } +} + +template +void Complexify(const MatPtrT& x, MatPtrT>& c_x) { + for (size_t i = 0; i < x.NumElements(); ++i) { + c_x.At(i) = std::complex(x.At(i), 0.0); + } +} + +template +void Complexify(const CompressedLayer& w, + CompressedLayer& c_w) { Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); @@ -47,9 +77,9 @@ void Complexify(const Layer& w, Complexify(w.linear_w, c_w.linear_w); } -template -void Complexify(const Weights& w, - Weights, TConfig>& c_w) { +template +void Complexify(const CompressedWeights& w, + CompressedWeights& c_w) { static constexpr size_t kLayers = TConfig::kLayers; Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); Complexify(w.final_norm_scale, c_w.final_norm_scale); @@ -58,19 +88,41 @@ void Complexify(const Weights& w, } } -template -void TestNear(const std::array& actual, const std::array& expected, +// Owns weights and provides access to TConfig. +template +class WeightsWrapper { + public: + WeightsWrapper() + : pool_(0), + data_(AllocateCompressedWeights()(pool_)), + weights_(reinterpret_cast*>(data_.get())) {} + + const CompressedWeights& get() const { return *weights_; } + CompressedWeights& get() { return *weights_; } + void ZeroInit() { weights_->ZeroInit(); } + void CopyFrom(const WeightsWrapper& other) { + get().CopyFrom(other.get()); + } + + private: + hwy::ThreadPool pool_; + ByteStorageT data_; + CompressedWeights* weights_; +}; + +template +void TestNear(const MatPtrT& actual, const MatPtrT& expected, double max_abs_err, double max_rel_err, int line) { double sum0 = 0; double sum1 = 0; double sum01 = 0; - for (size_t i = 0; i < N; ++i) { - sum0 += actual[i] * actual[i]; - sum1 += expected[i] * expected[i]; - sum01 += actual[i] * expected[i]; - ASSERT_NEAR(actual[i], expected[i], - std::max(max_abs_err, std::abs(expected[i]) * max_rel_err)) - << "line: " << line << " dim=" << N << " i=" << i; + for (size_t i = 0; i < actual.NumElements(); ++i) { + sum0 += actual.At(i) * actual.At(i); + sum1 += expected.At(i) * expected.At(i); + sum01 += actual.At(i) * expected.At(i); + ASSERT_NEAR(actual.At(i), expected.At(i), + std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err)) + << "line: " << line << " dim=" << expected.NumElements() << " i=" << i; } if (sum0 > 1e-40) { double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1); @@ -93,48 +145,37 @@ void TestNear(const std::array& actual, const std::array& expected, // This method is more numerically stable than the real-valued finite difference // method since we don't need to subtract floating point numbers that are near // to each other. 
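Editorial aside: the comment above describes complex-step differentiation -- for an analytic f, Im(f(x + ih))/h approximates f'(x) with no subtractive cancellation, which is why the step in TestGradient below can be as small as 1e-30 or 1e-50. A standalone illustration:

#include <complex>
#include <cstdio>

int main() {
  const double h = 1e-30;
  const auto f = [](std::complex<double> x) { return x * x * x; };  // f(x) = x^3
  const double x0 = 2.0;
  // d/dx x^3 = 3*x^2 = 12 at x0 = 2, recovered from the imaginary part alone.
  const double dfdx = std::imag(f(std::complex<double>(x0, h))) / h;
  std::printf("f'(%g) ~= %g\n", x0, dfdx);
  return 0;
}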
-template -void TestGradient(const std::array& grad, - std::array, N>& x, FUNC func, - U step, T max_abs_err, T max_rel_err, int line) { - std::array exp_grad; +template +void TestGradient(const MatPtrT& grad, MatPtrT>& x, + FUNC func, U step, T max_abs_err, T max_rel_err, int line) { + MatStorageT exp_grad("exp_grad", x.Rows(), x.Cols()); const U inv_step = 1.0 / step; - for (size_t i = 0; i < N; ++i) { - const U x0 = std::real(x[i]); + for (size_t i = 0; i < x.NumElements(); ++i) { + const U x0 = std::real(x.At(i)); const std::complex x1 = std::complex(x0, step); - x[i] = x1; + x.At(i) = x1; const std::complex f1 = func(); - exp_grad [i] = std::imag(f1) * inv_step; - x[i] = x0; + exp_grad.At(i) = std::imag(f1) * inv_step; + x.At(i) = x0; } TestNear(grad, exp_grad, max_abs_err, max_rel_err, line); } -template -void TestGradient(const std::array& grad, - std::array, N>& x, FUNC func, - float max_abs_err, float max_rel_error, int line) { +template +void TestGradient(const MatPtrT& grad, MatPtrT>& x, + FUNC func, float max_abs_err, float max_rel_error, int line) { TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line); } -template -void TestGradient(const std::array& grad, - std::array, N>& x, FUNC func, - float max_abs_err, float max_rel_error, int line) { +template +void TestGradient(const MatPtrT& grad, MatPtrT>& x, + FUNC func, T max_abs_err, T max_rel_error, int line) { TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); } -template -void TestGradient(const std::array& grad, - std::array, N>& x, FUNC func, - double max_abs_err, double max_rel_error, int line) { - TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); -} - -template -void TestGradient(const Layer& grad, - Layer, TConfig>& c_weights, - FUNC func, T max_err) { +template +void TestGradient(const CompressedLayer& grad, + CompressedLayer& c_weights, FUNC func, T max_err) { TestGradient(grad.pre_attention_norm_scale, c_weights.pre_attention_norm_scale, func, max_err, max_err, __LINE__); @@ -150,10 +191,9 @@ void TestGradient(const Layer& grad, func, max_err, max_err, __LINE__); } -template -void TestGradient(const Weights& grad, - Weights, TConfig>& c_weights, - FUNC func, T max_err) { +template +void TestGradient(const CompressedWeights& grad, + CompressedWeights& c_weights, FUNC func, T max_err) { TestGradient(grad.embedder_input_embedding, c_weights.embedder_input_embedding, func, 2 * max_err, max_err, __LINE__); diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index c826d43..32af763 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -152,6 +152,7 @@ cc_test( cc_library( name = "compress", + srcs = ["compress.cc"], hdrs = [ "compress.h", "shared.h", @@ -207,30 +208,17 @@ cc_library( ], ) -cc_library( - name = "weights_raw", - hdrs = ["weights_raw.h"], - deps = [ - "//:allocator", - "//:common", - "@hwy//:hwy", - "@hwy//:thread_pool", - ], -) - cc_binary( name = "compress_weights", srcs = ["compress_weights.cc"], deps = [ ":compress", ":io", - ":weights_raw", "//:allocator", "//:args", "//:common", "//:weights", "@hwy//:hwy", - "@hwy//:profiler", "@hwy//:thread_pool", ], ) diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 13c9563..24248a1 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -19,7 +19,10 @@ #include #include +#include #include +#include +#include #include #include "compression/io.h" @@ -45,6 +48,13 @@ hwy::uint128_t MakeKey(const char* string) { return ret; } +std::string 
StringFromKey(hwy::uint128_t key) { + std::string name(sizeof(key) + 1, '\0'); + hwy::CopyBytes(&key, name.data(), sizeof(key)); + name.resize(name.find('\0')); + return name; +} + namespace { void EnqueueChunkRequests(uint64_t offset, uint64_t size, uint8_t* data, std::vector& requests) { @@ -226,15 +236,23 @@ BlobError BlobReader::Open(const Path& filename) { return blob_store_->CheckValidity(file_->FileSize()); } +size_t BlobReader::BlobSize(hwy::uint128_t key) const { + uint64_t offset; + size_t size; + if (!blob_store_->FindKey(key, offset, size)) return 0; + return size; +} + BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { uint64_t offset; size_t actual_size; if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; if (actual_size != size) { fprintf(stderr, - "Mismatch between expected %d and actual %d KiB size. Please see " - "README.md on how to update the weights.\n", - static_cast(size >> 10), static_cast(actual_size >> 10)); + "Mismatch between expected %d and actual %d KiB size of blob %s. " + "Please see README.md on how to update the weights.\n", + static_cast(size >> 10), static_cast(actual_size >> 10), + StringFromKey(key).c_str()); return __LINE__; } @@ -265,6 +283,17 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) { return 0; } +BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data, + size_t size) const { + uint64_t offset; + size_t actual_size; + if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; + if (!file_->Read(offset, actual_size, data)) { + return __LINE__; + } + return 0; +} + BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) { HWY_ASSERT(keys_.size() == blobs_.size()); diff --git a/compression/blob_store.h b/compression/blob_store.h index c881564..4aba006 100644 --- a/compression/blob_store.h +++ b/compression/blob_store.h @@ -20,6 +20,7 @@ #include #include +#include #include #include "compression/io.h" @@ -32,6 +33,9 @@ namespace gcpp { // Convenient way to construct a key from a string (<= 16 chars). hwy::uint128_t MakeKey(const char* string); +// Returns a string from a key. +std::string StringFromKey(hwy::uint128_t key); + // Ordered list of opaque blobs (~hundreds), identified by unique opaque // 128-bit keys. class BlobStore; @@ -67,6 +71,9 @@ class BlobReader { // Opens `filename` and reads its header. BlobError Open(const Path& filename); + // Returns the size of the blob identified by `key`, or 0 if not found. + size_t BlobSize(hwy::uint128_t key) const; + // Enqueues read requests if `key` is found and its size matches `size`, which // is in units of bytes. BlobError Enqueue(hwy::uint128_t key, void* data, size_t size); @@ -74,6 +81,9 @@ class BlobReader { // Reads all enqueued requests. BlobError ReadAll(hwy::ThreadPool& pool); + // Reads one blob directly. + BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const; + private: BlobStorePtr blob_store_; // holds header, not the entire file std::vector requests_; diff --git a/compression/compress-inl.h b/compression/compress-inl.h index e47a979..79f8b40 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -471,14 +471,15 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, } } -// Adapter that compresses into `CompressedArray`. `raw` must already be scaled +// Adapter that compresses into `MatStorageT`. `raw` must already be scaled // to fit the value range, if `Packed` is `SfpStream`. 
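Editorial aside: a usage sketch for the BlobReader additions above (BlobSize and ReadOne), with simplified error handling; the helper name is hypothetical, and a BlobError of 0 means success as elsewhere in blob_store.

#include <cstdint>
#include <vector>
#include "compression/blob_store.h"
#include "compression/io.h"

bool ReadBlobByName(const gcpp::Path& filename, const char* name,
                    std::vector<uint8_t>& out) {
  gcpp::BlobReader reader;
  if (reader.Open(filename) != 0) return false;    // nonzero BlobError
  const hwy::uint128_t key = gcpp::MakeKey(name);  // name must fit in 16 chars
  const size_t size = reader.BlobSize(key);        // 0 if the key is missing
  if (size == 0) return false;
  out.resize(size);
  return reader.ReadOne(key, out.data(), size) == 0;
}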
-template +template HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num, CompressWorkingSet& work, - CompressedArray& compressed, + MatStorageT& compressed, hwy::ThreadPool& pool) { - Compress(raw, num, work, MakeSpan(compressed.data(), kCapacity), + Compress(raw, num, work, + MakeSpan(compressed.data(), compressed.NumElements()), /*packed_ofs=*/0, pool); } @@ -674,28 +675,24 @@ class Compressor { public: explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {} - template - void operator()(const char* name, const float* HWY_RESTRICT weights, - CompressedArray& compressed) { - Insert(name, weights, kCapacity, work_, compressed.GetSpan(), - /*packed_ofs=*/0, pool_); - } - template - void Insert(const char* name, const float* HWY_RESTRICT weights, - size_t num_weights, CompressWorkingSet& work, - const PackedSpan& packed, size_t packed_ofs, - hwy::ThreadPool& pool) { - fprintf(stderr, "Compressing %s (%zuM), please wait\n", name, + void operator()(MatPtrT* compressed, const char* decorated_name, + const float* HWY_RESTRICT weights) { + int num_weights = compressed->NumElements(); + int num_compressed = compressed->NumElements(); + PackedSpan packed = MakeSpan(compressed->data(), num_compressed); + fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name, num_weights / (1000 * 1000)); - Compress(weights, num_weights, work_, packed, packed_ofs, pool_); + Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, pool_); const size_t num_bytes = packed.num * sizeof(Packed); - writer_.Add(CacheKey(name), packed.ptr, num_bytes); + writer_.Add(MakeKey(decorated_name), packed.ptr, num_bytes); } void AddScales(const float* scales, size_t len) { if (len) { - writer_.Add(CacheKey("scales"), scales, len * sizeof(scales[0])); + MatPtrT scales_ptr("scales", 0, 1); + writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales, + len * sizeof(scales[0])); } } diff --git a/compression/compress.cc b/compression/compress.cc new file mode 100644 index 0000000..e858e15 --- /dev/null +++ b/compression/compress.cc @@ -0,0 +1,22 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compression/compress.h" + +namespace gcpp { + +MatPtr::~MatPtr() {} + +} // namespace gcpp diff --git a/compression/compress.h b/compression/compress.h index 275306f..e0ea0d7 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -23,7 +23,11 @@ #include #include +#include +#include #include +#include +#include #include // IWYU pragma: begin_exports @@ -32,7 +36,8 @@ #include "compression/shared.h" // IWYU pragma: end_exports #include "compression/distortion.h" -#include "hwy/base.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // BF16 #include "hwy/contrib/thread_pool/thread_pool.h" #if COMPRESS_STATS #include "hwy/stats.h" @@ -82,6 +87,376 @@ class CompressedArray { float scale_[kBlobAlign / sizeof(float)]; }; +// Yet another array class. 
This one is intended to be compatible with +// CompressedArray, but have both run-time sizing and compile-time constant +// size. +// It also provides easy conversion from/to a table of contents for a BlobStore +// file, and a templated (compile-time) accessor for a 2-d array of fixed inner +// dimension and type. +// The base class is intended for accessing the metadata, without needing to +// know any of the template arguments. +// It holds only a borrowed pointer to the data, but all metadata. +// It is designed to be put in a vector, and has default copy and operator=, so +// it is easy to read/write a blob_store file. +// The derived class or an external class owns the data. +class MatPtr { + public: + // Full constructor for dynamic sizing. + MatPtr(const std::string& name, const std::string& type, size_t element_size, + size_t rows, size_t cols) + : name_(name), + type_(type), + element_size_(element_size), + num_elements_(rows * cols), + rows_(rows), + cols_(cols), + ptr_(nullptr) {} + // Default constructor doesn't set anything. + MatPtr() = default; + virtual ~MatPtr(); + + // Number of hwy::uint128_t in a TOC entry. + // Note that the old-style BlobStore files Only have a list of keys and size. + // The new-style BlobStore files have an entry called "toc" that contains a + // vector of 4-tuples of + // (name, type, (num_elements, element_size), (rows, cols)). + // The listed blobs can be read directly into MatPtr from the BlobStore + // file, without needing any external knowledge of the number of elements, + // element size or type of the data. + static constexpr size_t kNumU128InTocEntry = 4; + + // Construct from a TOC entry. + MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1, + const hwy::uint128_t& key2, const hwy::uint128_t& key3) + : name_(StringFromKey(key0)), + type_(StringFromKey(key1)), + element_size_(key2.hi), + num_elements_(key2.lo), + rows_(key3.lo), + cols_(key3.hi) {} + + // Adds the contents entry to the table of contents. + void AddToToc(std::vector& toc) const { + toc.push_back(MakeKey(name_.c_str())); + toc.push_back(MakeKey(type_.c_str())); + toc.push_back({num_elements_, element_size_}); + toc.push_back({rows_, cols_}); + } + + // Compatibility interface for CompressedArray. + template + T* data() { + return HWY_RCAST_ALIGNED(T*, ptr_); + } + template + const T* data() const { + return HWY_RCAST_ALIGNED(const T*, ptr_); + } + + const void* Ptr() const { return ptr_; } + void* Ptr() { return ptr_; } + // Sets the pointer from another MatPtr. + void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; } + + // Copying allowed as the metadata is small. + MatPtr(const MatPtr& other) = default; + MatPtr& operator=(const MatPtr& other) = default; + + // Returns the name of the blob. + const std::string& Name() const { return name_; } + void SetName(const std::string& name) { name_ = name; } + + // Returns the type of the blob. + const std::string& Type() const { return type_; } + + // Returns the size of each element in bytes. + size_t ElementSize() const { return element_size_; } + + // Returns the number of elements in the array. + size_t NumElements() const { return num_elements_; } + + // Returns the number of bytes in the array. + size_t SizeBytes() const { return num_elements_ * element_size_; } + size_t CompressedSize() const { return SizeBytes(); } + + // Returns the number of rows in the 2-d array (outer dimension). + size_t Rows() const { return rows_; } + + // Returns the number of columns in the 2-d array (inner dimension). 
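Editorial aside: a sketch of round-tripping MatPtr metadata through the 4-entry TOC layout documented above, assuming the TOC is a vector of hwy::uint128_t as the push_back calls imply.

#include <vector>
#include "compression/compress.h"

// Serializes one tensor's metadata (name, type, sizes, shape) into the TOC.
std::vector<hwy::uint128_t> TocFor(const gcpp::MatPtr& tensor) {
  std::vector<hwy::uint128_t> toc;
  tensor.AddToToc(toc);  // appends kNumU128InTocEntry (= 4) keys
  return toc;
}

// Rebuilds the metadata-only MatPtr from the first TOC entry; the data pointer
// still has to be attached (e.g. via SetPtr) after reading the blob itself.
gcpp::MatPtr FirstEntry(const std::vector<hwy::uint128_t>& toc) {
  return gcpp::MatPtr(toc[0], toc[1], toc[2], toc[3]);
}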
+ size_t Cols() const { return cols_; } + + // Decoded elements should be multiplied by this to restore their original + // range. This is required because SfpStream can only encode a limited range + // of magnitudes. + float scale() const { return scale_; } + void set_scale(float scale) { scale_ = scale; } + + std::string LayerName(int layer) const { + std::string name = name_ + std::to_string(layer); + HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t)); + return name; + } + + // Adds the blob to the writer. + void AddToWriter(BlobWriter& writer) const { + fprintf(stderr, "Adding %s to writer\n", name_.c_str()); + writer.Add(MakeKey(name_.c_str()), ptr_, SizeBytes()); + } + + // Sets all data to zero. + void ZeroInit() { + if (ptr_ == nullptr) + HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str()); + hwy::ZeroBytes(ptr_, SizeBytes()); + } + + // Calls func on the upcasted type. Since MatPtr by design is not templated, + // here we provide a way to get to the derived type, provided that the type + // matches one of a known short-list. + template + decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args); + + protected: + // Arbitrary name for the array of preferably <= 16 characters. + std::string name_; + // Should be the result of TypeName for CallUpcasted() to work. + std::string type_; + // sizeof(T) + size_t element_size_ = 0; + // Number of elements in the array. + size_t num_elements_ = 0; // In element_size units. + // Number of rows in the 2-d array (outer dimension). + size_t rows_ = 0; + // Number of columns in the 2-d array (inner dimension). + size_t cols_ = 0; + // Scaling to apply to each element. + float scale_ = 1.0f; + // Aligned data array. This is always a borrowed pointer. It should never be + // freed. The underlying memory is owned by a subclass or some external class + // and must outlive this object. + void* ptr_ = nullptr; +}; + +// MatPtrT adds a single template argument to MatPtr for an explicit type. +// Use this class as a function argument where the type needs to be known. +// Use MatPtr where the type does not need to be known. +template +class MatPtrT : public MatPtr { + public: + using value_type = MatT; + + // Full constructor for dynamic sizing. + MatPtrT(const std::string& name, size_t rows, size_t cols) + : MatPtr(name, TypeName(), sizeof(MatT), rows, cols) {} + + // Copying allowed as the metadata is small. + MatPtrT(const MatPtr& other) : MatPtr(other) {} + MatPtrT& operator=(const MatPtr& other) { + MatPtr::operator=(other); + return *this; + } + MatPtrT(const MatPtrT& other) = default; + MatPtrT& operator=(const MatPtrT& other) = default; + + std::string CacheName(int layer = -1, char separator = ' ', + int index = -1) const { + // Already used/retired: s, S, n, 1 + const char prefix = hwy::IsSame() ? 'F' + : hwy::IsSame() ? 'B' + : hwy::IsSame() ? '$' + : hwy::IsSame() ? '2' + : '?'; + std::string name = std::string(1, prefix) + name_; + if (layer >= 0 || index >= 0) { + name += '_'; + if (layer >= 0) name += std::to_string(layer); + if (index >= 0) { + name += separator + std::to_string(index); + } + } + return name; + } + // Sets the number of elements in the array. For use when the number of + // elements is != rows * cols ONLY. + void SetNumElements(size_t num_elements) { + num_elements_ = CompressedArrayElements(num_elements); + } + + // Fast 2-d accessor for a 2-d array of fixed inner dimension and type. 
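A minimal sketch (not part of the patch) of the names CacheName() produces, using the type prefixes listed above ('$' for SfpStream). "qkv_ein" and the 1x1 shape are example values only.

void CacheNameExample() {
  MatPtrT<SfpStream> qkv("qkv_ein", /*rows=*/1, /*cols=*/1);
  // Prefix only; the codes match those of the CacheKey() removed below.
  HWY_ASSERT(qkv.CacheName() == "$qkv_ein");
  // With a layer index, an underscore and the layer number are appended.
  HWY_ASSERT(qkv.CacheName(/*layer=*/3) == "$qkv_ein_3");
}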
+ template + const T& AtT(size_t row, size_t col) const { + size_t index = row * kInner + col; + HWY_DASSERT(index < num_elements_); + return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; + } + + // 2-d Accessor for a specific type but with a dynamic inner dimension. + template + const T& At(size_t row, size_t col) const { + size_t index = row * cols_ + col; + HWY_DASSERT(index < num_elements_); + return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; + } + + // 1-d Accessor for a specific type. + template + const T& At(size_t index) const { + HWY_DASSERT(index < num_elements_); + return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; + } + template + T& At(size_t index) { + return HWY_RCAST_ALIGNED(T*, ptr_)[index]; + } + + // Compatibility interface for CompressedArray. + template + T* data() { + return HWY_RCAST_ALIGNED(T*, ptr_); + } + template + const T* data() const { + return HWY_RCAST_ALIGNED(const T*, ptr_); + } + // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so + // calling it means "I am sure the scale is 1 and therefore ignore the scale". + // A scale of 0 indicates that the scale has likely never been set, so is + // "implicitly 1". + const MatT* data_scale1() const { + HWY_ASSERT(scale() == 1.f); + return HWY_RCAST_ALIGNED(const MatT*, ptr_); + } +}; + +template +decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { + if (type_ == TypeName()) { + return func(dynamic_cast*>(this), + std::forward(args)...); + } else if (type_ == TypeName()) { + return func(dynamic_cast*>(this), + std::forward(args)...); + } else if (type_ == TypeName()) { + return func(dynamic_cast*>(this), + std::forward(args)...); + } else { + HWY_ABORT("Type %s unknown.", type_.c_str()); + } +} + +// MatStorageT adds the actual data storage to MatPtrT. +template +class MatStorageT : public MatPtrT { + public: + // Full constructor for dynamic sizing. + MatStorageT(const std::string& name, size_t rows, size_t cols) + : MatPtrT(name, rows, cols), + data_(hwy::AllocateAligned( + hwy::DivCeil(this->SizeBytes(), sizeof(MatT)))) { + this->ptr_ = data_.get(); + } + // Can copy the metadata, from a MatPtr, and allocate later. + MatStorageT(const MatPtr& other) : MatPtrT(other) {} + + // No copying of MatStorageT as it contains big data. + MatStorageT(const MatStorageT& other) = delete; + MatStorageT& operator=(const MatStorageT& other) = delete; + MatStorageT(MatStorageT&& other) = default; + MatStorageT& operator=(MatStorageT&& other) = default; + + // Allocate the memory and copy the pointer to the MatPtr. + // num_elements is in elements. In the default (zero) case, it is computed + // from the current num_elements_ which was set by the constructor from the + // rows and cols. + void Allocate(size_t num_elements = 0) { + if (num_elements == 0) { + num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT)); + } else { + this->num_elements_ = num_elements; + } + data_ = hwy::AllocateAligned(num_elements); + this->ptr_ = data_.get(); + } + + // Zeros the content. + void ZeroInit() { + HWY_ASSERT(data_ != nullptr); + hwy::ZeroBytes(data_.get(), this->SizeBytes()); + } + + private: + // Aligned data array. + // std::unique_ptr data_; + hwy::AlignedFreeUniquePtr data_; +}; + +// MatStorage allows heterogeneous tensors to be stored in a single vector. +using MatStorage = MatStorageT; + +// Table of contents for a blob store file. Full metadata, but not actual data. +class BlobToc { + public: + BlobToc() = default; + + // Adds all blobs to the blob writer. 
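A minimal sketch (not part of the patch) of using CallUpcasted() to recover the element type from a type-erased MatPtr, the same pattern by which Compressor is invoked in compress_weights.cc later in this patch. SizeReporter and ReportSize are made-up names.

struct SizeReporter {
  template <typename T>
  void operator()(MatPtrT<T>* tensor) const {
    fprintf(stderr, "%s: %zu elements of %zu bytes each\n",
            tensor->Name().c_str(), tensor->NumElements(), sizeof(T));
  }
};

void ReportSize(MatPtr& any_tensor) {
  SizeReporter reporter;
  any_tensor.CallUpcasted(reporter);  // dispatches on the stored Type() string
}

// e.g. calling ReportSize on a MatStorageT<float>("embedding", 256, 64) prints
// "embedding: 16384 elements of 4 bytes each".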
Note that the blobs must have unique + // names. + static void AddAllToBlobWriter(const std::vector& blobs, + BlobWriter& writer) { + std::vector toc; + for (const auto& blob : blobs) { + blob.AddToToc(toc); + blob.AddToWriter(writer); + } + writer.Add(MakeKey(kTocName), toc.data(), toc.size() * sizeof(toc[0])); + } + + // Loads the table of contents from the given reader. + BlobError LoadToc(BlobReader& reader) { + hwy::uint128_t toc_key = MakeKey(kTocName); + size_t toc_size = reader.BlobSize(toc_key); + if (toc_size != 0) { + std::vector toc(toc_size / sizeof(hwy::uint128_t)); + BlobError err = reader.ReadOne(toc_key, toc.data(), toc_size); + if (err != 0) { + fprintf(stderr, "Failed to read toc (error %d)\n", err); + return err; + } + for (size_t i = 0; i < toc.size(); i += MatPtr::kNumU128InTocEntry) { + AddToToc(MatPtr(toc[i], toc[i + 1], toc[i + 2], toc[i + 3])); + } + } + return 0; + } + + bool Empty() const { return toc_map_.empty(); } + + // Returns true if the table of contents contains the given name. + bool Contains(const std::string& name) const { + return toc_map_.find(name) != toc_map_.end(); + } + + // Returns the blob with the given name, or nullptr if not found. + const MatPtr* Get(const std::string& name) const { + auto it = toc_map_.find(name); + if (it == toc_map_.end()) return nullptr; + return &toc_[it->second]; + } + + private: + // The name of the toc in the blob store file. + static constexpr char kTocName[] = "toc"; + + // Adds the blob to the table of contents. + void AddToToc(const MatPtr& blob) { + HWY_ASSERT(!Contains(blob.Name())); + toc_map_[blob.Name()] = toc_.size(); + toc_.push_back(blob); + } + + std::unordered_map toc_map_; + std::vector toc_; +}; + #if COMPRESS_STATS class CompressStats { public: @@ -146,21 +521,6 @@ struct CompressWorkingSet { std::vector tls; }; -// Returns key for the given tensor name. Also encodes the type, so that -// changing the representation automatically invalidates prior cached files -// (the new blob name will not be found). -template -hwy::uint128_t CacheKey(const char* name) { - // Already used/retired: s, S, n, 1 - const char prefix = hwy::IsSame() ? 'F' - : hwy::IsSame() ? 'B' - : hwy::IsSame() ? '$' - : hwy::IsSame() ? '2' - : '?'; - - return MakeKey((std::string(1, prefix) + name).c_str()); -} - // Functor called for each tensor, which loads them and their scaling factors // from BlobStore. class CacheLoader { @@ -170,43 +530,82 @@ class CacheLoader { if (err_ != 0) { fprintf(stderr, "Cached compressed weights does not exist yet (code %d), " - "compressing weights and creating file: %s.\n", + "loading from file: %s.\n", err_, blob_filename.path.c_str()); } + err_ = file_toc_.LoadToc(reader_); + if (err_ != 0) { + fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_); + } } + // Returns true if there is a TOC. + bool HaveToc() const { return !file_toc_.Empty(); } + // Called for each tensor, enqueues read requests. - template - void operator()(const char* name, const float* null, - CompressedArray& compressed) { - HWY_DASSERT(null == nullptr); - - // Skip if reader_ is invalid or any load failed: we will regenerate - // everything because it's rare to update only a few tensors. 
- if (err_ != 0) return; - - const PackedSpan span = compressed.GetSpan(); - const size_t num_bytes = span.num * sizeof(Packed); - err_ = reader_.Enqueue(CacheKey(name), span.ptr, num_bytes); - compressed.set_scale(1.0f); - if (err_ != 0) { - fprintf(stderr, "Failed to read cache %s (error %d)\n", name, err_); + void operator()(const char* name, hwy::Span tensors) { + if (file_toc_.Empty() || file_toc_.Contains(name)) { + if (tensors[0]->NumElements() == 0) + fprintf(stderr, "Zero elements for %s\n", name); + model_toc_.push_back(tensors[0]); + file_keys_.push_back(name); } } - void LoadScales(float* scales, size_t len) { - if (0 != reader_.Enqueue(CacheKey("scales"), scales, - len * sizeof(scales[0]))) { - for (size_t i = 0; i < len; ++i) { - scales[i] = 1.0f; - } + BlobError LoadScales(float* scales, size_t len) { + for (size_t i = 0; i < len; ++i) { + scales[i] = 1.0f; } + MatPtrT scales_ptr("scales", 0, 1); + auto key = MakeKey(scales_ptr.CacheName().c_str()); + if (reader_.BlobSize(key) == 0) return 0; + return reader_.Enqueue(key, scales, len * sizeof(scales[0])); } // Returns whether all tensors are successfully loaded from cache. - bool ReadAll(hwy::ThreadPool& pool) { + bool ReadAll(hwy::ThreadPool& pool, std::vector& model_memory) { // reader_ invalid or any Enqueue failed if (err_ != 0) return false; + // Setup the model_memory. + for (int b = 0; b < model_toc_.size(); ++b) { + const std::string& file_key = file_keys_[b]; + MatPtr* blob = model_toc_[b]; + if (!file_toc_.Empty()) { + const MatPtr* toc_blob = file_toc_.Get(file_key); + if (toc_blob == nullptr) { + fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str()); + return false; + } + if (toc_blob->Rows() != blob->Rows() || + toc_blob->Cols() != blob->Cols()) { + fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); + return false; + } + MatStorage toc_blob_array(*toc_blob); + model_memory.push_back(std::move(toc_blob_array)); + } else { + model_memory.emplace_back(*blob); + model_memory.back().SetName(file_key); + } + } + // Allocate in parallel using the pool. + pool.Run(0, model_memory.size(), + [this, &model_memory](uint64_t task, size_t /*thread*/) { + model_memory[task].Allocate(); + model_toc_[task]->SetPtr(model_memory[task]); + }); + // Enqueue the read requests. + for (auto& blob : model_memory) { + err_ = reader_.Enqueue(MakeKey(blob.Name().c_str()), blob.data(), + blob.SizeBytes()); + if (err_ != 0) { + fprintf(stderr, + "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", + blob.Name().c_str(), err_, blob.Rows(), blob.Cols(), + blob.ElementSize()); + return false; + } + } err_ = reader_.ReadAll(pool); if (err_ != 0) { @@ -220,6 +619,13 @@ class CacheLoader { private: BlobReader reader_; BlobError err_ = 0; + // Table of contents from the file, if present. + BlobToc file_toc_; + // Table of contents from the model. Pointers to original MatPtrT so the + // data pointers can be updated. + std::vector model_toc_; + // Mangled names of the tensors in model_toc_ for reading from the file. 
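A minimal sketch (not part of the patch) of the intended CacheLoader calling sequence. Real callers enumerate tensors via ForEachTensor (see gemma/weights.cc later in this patch); here two hand-made tensors stand in for a model, and the file name is made up.

void LoadTwoTensors(hwy::ThreadPool& pool) {
  Path blob_path;
  blob_path.path = "weights.sbs";
  CacheLoader loader(blob_path);
  MatPtrT<float> w1("w1", 16, 16);
  MatPtrT<float> w2("w2", 16, 16);
  MatPtr* t1[1] = {&w1};
  MatPtr* t2[1] = {&w2};
  loader("w1", hwy::Span<MatPtr*>(t1, 1));  // records metadata and file key
  loader("w2", hwy::Span<MatPtr*>(t2, 1));
  std::vector<MatStorage> model_memory;     // owns the allocations
  if (!loader.ReadAll(pool, model_memory)) {
    HWY_ABORT("Failed to load weights.");
  }
  // On success, w1 and w2 now point into buffers held by model_memory.
}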
+ std::vector file_keys_; }; } // namespace gcpp diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index 41b46be..1ae400e 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -36,155 +36,23 @@ #include #include #include // NOLINT +#include +#include "compression/compress.h" #include "compression/io.h" // Path -#include "compression/shared.h" -#include "compression/weights_raw.h" -#include "gemma/common.h" // Model +#include "gemma/common.h" // Model #include "gemma/weights.h" #include "util/allocator.h" #include "util/args.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" namespace gcpp { -// Setting this to true disables fread() calls that read the model file. -constexpr bool kDryRunFread = false; - namespace { -#define READ_WEIGHTS(name) \ - do { \ - do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \ - } while (0) - -#define SCALE_WEIGHTS(name) \ - do { \ - if (ok && !kDryRunFread && scale_for_compression) { \ - weights->scales[scale_pos++] = \ - ScaleWeights(layer_view->name.data(), layer_view->name.size()); \ - } \ - } while (0) - -template -struct LoadRawWeightsT { - ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool, - bool scale_for_compression) const { - PROFILER_ZONE("Startup.LoadWeights"); - if (!checkpoint.Exists()) { - HWY_ABORT("The model weights file '%s' does not exist.", - checkpoint.path.c_str()); - } - - ByteStorageT weights_u8 = AllocateWeightsF()(pool); - auto* weights = reinterpret_cast*>(weights_u8.get()); - - size_t scale_pos = 0; - FILE* fptr; - if constexpr (kDryRunFread) { - fprintf(stderr, "Dry-Run, not reading model-file.\n"); - } else { - fptr = fopen(checkpoint.path.c_str(), "rb"); - if (fptr == nullptr) { - HWY_ABORT("Failed to open model file %s - does it exist?", - checkpoint.path.c_str()); - } - } - bool ok = true; - uint64_t total_size = 0; - auto do_fread = [&](void* var, int layer, const char* name, size_t size) { - if (layer == -1) { - fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name); - } else { - fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer, - size, name); - } - if constexpr (!kDryRunFread) { - ok &= 1 == fread(var, size, 1, fptr); - total_size += size; - } - }; - do_fread(&(weights->embedder_input_embedding), -1, - "embedder_input_embedding", - sizeof(weights->embedder_input_embedding)); - do_fread(&(weights->final_norm_scale), -1, "final_norm_scale", - sizeof(weights->final_norm_scale)); - for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - auto type = TConfig::kLayerConfig[layer]; - LayerF* layer_view = weights->GetLayer(layer); - - // Make sure we don't have uninitialized memory. 
- hwy::ZeroBytes(layer_view, sizeof(*layer_view)); - if (type == LayerAttentionType::kGemma) { - READ_WEIGHTS(attn_vec_einsum_w); - READ_WEIGHTS(qkv_einsum_w); - SCALE_WEIGHTS(attn_vec_einsum_w); - SCALE_WEIGHTS(qkv_einsum_w); - } else { - READ_WEIGHTS(griffin.linear_x_w); - READ_WEIGHTS(griffin.linear_x_biases); - READ_WEIGHTS(griffin.linear_y_w); - READ_WEIGHTS(griffin.linear_y_biases); - READ_WEIGHTS(griffin.linear_out_w); - READ_WEIGHTS(griffin.linear_out_biases); - READ_WEIGHTS(griffin.conv_w); - READ_WEIGHTS(griffin.conv_biases); - READ_WEIGHTS(griffin.gate_w); - READ_WEIGHTS(griffin.gate_biases); - READ_WEIGHTS(griffin.a); - SCALE_WEIGHTS(griffin.linear_x_w); - SCALE_WEIGHTS(griffin.linear_y_w); - SCALE_WEIGHTS(griffin.linear_out_w); - SCALE_WEIGHTS(griffin.gate_w); - } - READ_WEIGHTS(gating_einsum_w); - READ_WEIGHTS(linear_w); - SCALE_WEIGHTS(gating_einsum_w); - SCALE_WEIGHTS(linear_w); - READ_WEIGHTS(pre_attention_norm_scale); - READ_WEIGHTS(pre_ffw_norm_scale); - if (TConfig::kPostNorm == PostNormType::Scale) { - READ_WEIGHTS(post_attention_norm_scale); - READ_WEIGHTS(post_ffw_norm_scale); - } - if (TConfig::kFFBiases) { - READ_WEIGHTS(ffw_gating_biases); - READ_WEIGHTS(ffw_output_biases); - } - if (TConfig::kSoftmaxAttnOutputBiases && - type == LayerAttentionType::kGemma) { - READ_WEIGHTS(attention_output_biases); - } - } - if (!ok) { - HWY_ABORT( - "Failed to read from %s - might be a directory, or too small? " - "expected size: %d kB", - checkpoint.path.c_str(), static_cast(total_size >> 10)); - } - if (!kDryRunFread) { - HWY_ASSERT(0 == fclose(fptr)); - if (scale_for_compression) { - HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); - } - } - return weights_u8; - } -}; - -#undef READ_WEIGHTS -#undef SCALE_WEIGHTS } // namespace -ByteStorageT LoadRawWeights(const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool, - bool scale_for_compression) { - return CallForModelAndWeight( - model_type, weight_type, weights, pool, scale_for_compression); -} - struct Args : public ArgsBase { static constexpr size_t kDefaultNumThreads = ~size_t{0}; @@ -282,7 +150,7 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template +template void CompressWeights(const Path& weights_path, const Path& compressed_weights_path, Model model_type, Type weight_type, hwy::ThreadPool& pool) { @@ -290,26 +158,53 @@ void CompressWeights(const Path& weights_path, HWY_ABORT("The model weights file '%s' does not exist.", weights_path.path.c_str()); } + printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), + compressed_weights_path.path.c_str()); + using CConfig = Configs::c; + using UCConfig = Configs::uc; // Allocate compressed weights. - using CWeights = CompressedWeights; - ByteStorageT c_weights_u8 = AllocateSizeof(); + using CWeights = CompressedWeights; + ByteStorageT c_weights_u8 = AllocateCompressedWeights()(pool); CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); - new (&c_weights->c_layer_ptrs) CompressedLayerPointers(pool); - // Get weights, compress, and store. - const bool scale_for_compression = TConfig::kNumTensorScales > 0; - const ByteStorageT weights_u8 = gcpp::LoadRawWeights( - weights_path, model_type, weight_type, pool, scale_for_compression); - WeightsF* weights = - reinterpret_cast*>(weights_u8.get()); + // Allocate uncompressed weights. 
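A minimal sketch (not part of the patch) of the config pair that CompressWeights is now templated on: one config describes the uncompressed float weights being read, the other the compressed type being written. The parameter order (uncompressed first) is inferred from the uc/c aliases and GEMMA_DISPATCH_MODEL later in this patch; Gemma2B with SfpStream is only an example combination.

using ExampleConfigs =
    ConfigPair<ConfigGemma2B<float>, ConfigGemma2B<SfpStream>>;
// CompressWeights<ExampleConfigs> reads via ExampleConfigs::uc and
// compresses/writes via ExampleConfigs::c.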
+ using UCWeights = CompressedWeights; + ByteStorageT uc_weights_u8 = AllocateCompressedWeights()(pool); + UCWeights* uc_weights = reinterpret_cast(uc_weights_u8.get()); + + // Get uncompressed weights, compress, and store. + FILE* fptr = fopen(weights_path.path.c_str(), "rb"); + if (fptr == nullptr) { + HWY_ABORT("Failed to open model file %s - does it exist?", + weights_path.path.c_str()); + } + bool ok = true; + uint64_t total_size = 0; + CompressedWeights::ForEachTensor( + {uc_weights}, ForEachType::kLoadNoToc, + [&](const char* name, hwy::Span tensors) { + fprintf(stderr, "Loading Parameters (size %zu): %s\n", + tensors[0]->SizeBytes(), name); + ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); + total_size += tensors[0]->SizeBytes(); + }); + const bool scale_for_compression = UCConfig::kNumTensorScales > 0; + std::vector scales; + if (scale_for_compression) { + uc_weights->GetOrApplyScales(scales); + } Compressor compressor(pool); - ForEachTensor>(weights, *c_weights, compressor); - compressor.AddScales(weights->scales.data(), weights->scales.size()); + CompressedWeights::ForEachTensor( + {reinterpret_cast*>(uc_weights), c_weights}, + ForEachType::kLoadNoToc, + [&compressor](const char* name, hwy::Span tensors) { + tensors[1]->CallUpcasted( + compressor, name, + reinterpret_cast(tensors[0]->Ptr())); + }); + compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0])); compressor.WriteAll(pool, compressed_weights_path); - - weights->layer_ptrs.~LayerPointers(); - c_weights->c_layer_ptrs.~CompressedLayerPointers(); } } // namespace HWY_NAMESPACE diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index a6dcd11..a9d3894 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -53,34 +53,35 @@ namespace HWY_NAMESPACE { class SbsWriterImpl : public WriterInterface { template - hwy::AlignedFreeUniquePtr AllocateAndCompress( - const std::string& name, absl::Span weights) { + void AllocateAndCompress(const std::string& name, + absl::Span weights) { const size_t num_packed = CompressedArrayElements(weights.size()); - auto packed = hwy::AllocateAligned(num_packed); - PackedSpan span = MakeSpan(packed.get(), num_packed); - compressor_.Insert(name.c_str(), weights.data(), weights.size(), - working_set_, span, /*packed_ofs=*/0, pool_); - return packed; + MatPtrT storage(name, 1, num_packed); + model_memory_.push_back(storage); + model_memory_.back().Allocate(); + storage.SetPtr(model_memory_.back()); + std::string decorated_name = storage.CacheName(); + compressor_(&storage, decorated_name.c_str(), weights.data()); } public: SbsWriterImpl() : pool_(0), compressor_(pool_) {} void Insert(std::string name, absl::Span weights) override { - sfp_streams_.push_back(AllocateAndCompress(name, weights)); + AllocateAndCompress(name, weights); } void InsertNUQ(std::string name, absl::Span weights) override { - nuq_streams_.push_back(AllocateAndCompress(name, weights)); + AllocateAndCompress(name, weights); } void InsertBfloat16(std::string name, absl::Span weights) override { - bf16_streams_.push_back(AllocateAndCompress(name, weights)); + AllocateAndCompress(name, weights); } void InsertFloat(std::string name, absl::Span weights) override { - f32_streams_.push_back(AllocateAndCompress(name, weights)); + AllocateAndCompress(name, weights); } void AddScales(const std::vector& scales) override { @@ -96,10 +97,7 @@ class SbsWriterImpl : public WriterInterface { hwy::ThreadPool 
pool_; Compressor compressor_; CompressWorkingSet working_set_; - std::vector> sfp_streams_; - std::vector> nuq_streams_; - std::vector> bf16_streams_; - std::vector> f32_streams_; + std::vector model_memory_; std::vector scales_; }; diff --git a/compression/shared.h b/compression/shared.h index 166cd29..b79a067 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -22,6 +22,7 @@ #include #include +#include #include #include "hwy/aligned_allocator.h" @@ -184,6 +185,12 @@ const char* TypeName() { return "sfp"; } else if constexpr (hwy::IsSame()) { return "nuq"; + } else if constexpr (hwy::IsSame()) { + return "f64"; + } else if constexpr (hwy::IsSame>()) { + return "c64"; + } else if constexpr (hwy::IsSame()) { + return "u128"; } else { HWY_DASSERT(false); return "unknown"; diff --git a/compression/weights_raw.h b/compression/weights_raw.h deleted file mode 100644 index c6e6a73..0000000 --- a/compression/weights_raw.h +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_WEIGHTS_RAW_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_WEIGHTS_RAW_H_ - -// Historical note: this was the original f32-only simple on-disk format -// created by convert_weights.py. BlobStore is now the preferred on-disk -// format, and we load that into CompressedWeights. -// -// NOTE: this file should only be used by compress_weights. It is currently -// also referenced by backprop because it supports T = std::complex, and -// CompressedWeights might not yet. - -#include - -#include "gemma/configs.h" -#include "util/allocator.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -template -struct Layer { - Layer() {} - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; - static constexpr size_t kQKVEinsumWSize = - (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; - // 2x for (gelu gating vector, gated vector) - static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; - static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - static constexpr bool kFFBiases = TConfig::kFFBiases; - static constexpr PostNormType kPostNorm = TConfig::kPostNorm; - static constexpr size_t kAOBiasDim = - TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; - static constexpr size_t kGriffinDim = - TConfig::kGriffinLayers > 0 ? 
kModelDim : 0; - - union { - struct { - std::array attn_vec_einsum_w; - std::array qkv_einsum_w; - std::array attention_output_biases; - }; - - struct { - std::array linear_x_w; - std::array linear_x_biases; - std::array linear_y_w; - std::array linear_y_biases; - std::array linear_out_w; - std::array linear_out_biases; - std::array conv_w; - std::array conv_biases; - std::array gate_w; - std::array gate_biases; - std::array a; - } griffin; - }; - - std::array gating_einsum_w; - std::array linear_w; - std::array pre_attention_norm_scale; - std::array pre_ffw_norm_scale; - std::array - post_attention_norm_scale; - std::array - post_ffw_norm_scale; - - std::array ffw_gating_biases; - std::array ffw_output_biases; -}; - -template -using LayerF = Layer; - -// Array instead of single large allocation for parallel mem init. Split out of -// Weights so that only these pointers are initialized. -template -struct LayerPointers { - explicit LayerPointers(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { - this->layers[task] = hwy::AllocateAligned>(1); - }); - } - - using TLayer = Layer; - std::array, TConfig::kLayers> layers; -}; - -template -struct Weights { - // No ctor/dtor, allocated via AllocateAligned. - - std::array - embedder_input_embedding; - - std::array final_norm_scale; - - LayerPointers layer_ptrs; - - std::array scales; - - const Layer* GetLayer(size_t layer) const { - return layer_ptrs.layers[layer].get(); - } - Layer* GetLayer(size_t layer) { - return layer_ptrs.layers[layer].get(); - } -}; - -template -using WeightsF = Weights; - -// TODO: can we use TConfig::Weight instead of T? -template -struct AllocateWeights { - ByteStorageT operator()(hwy::ThreadPool& pool) const { - using TWeights = Weights; - ByteStorageT weights_u8 = AllocateSizeof(); - TWeights* weights = reinterpret_cast(weights_u8.get()); - new (&weights->layer_ptrs) LayerPointers(pool); - return weights_u8; - } -}; - -template -struct AllocateWeightsF { - ByteStorageT operator()(hwy::ThreadPool& pool) const { - return AllocateWeights()(pool); - } -}; - -// TODO: make a member of Weights. -template -struct ZeroInitWeights { - void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { - Weights& w = - *reinterpret_cast*>(weights.get()); - hwy::ZeroBytes(&w.embedder_input_embedding, - sizeof(w.embedder_input_embedding)); - hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); - for (int i = 0; i < TConfig::kLayers; ++i) { - hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i))); - } - } -}; - -template -struct ZeroInitWeightsF { - void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { - ZeroInitWeights()(weights, pool); - } -}; - -template -struct CopyWeights { -void operator()(Weights& dst, - const Weights& src) const { - hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding, - sizeof(src.embedder_input_embedding)); - hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale, - sizeof(src.final_norm_scale)); - for (int i = 0; i < TConfig::kLayers; ++i) { - hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), - sizeof(*dst.GetLayer(i))); - } - } -}; - -template -void RandInit(std::array& x, T stddev, std::mt19937& gen) { - std::normal_distribution dist(0.0, stddev); - for (size_t i = 0; i < kLen; ++i) { - x[i] = dist(gen); - } -} - -// TODO: make a member of Layer. 
-template -void RandInit(Layer& w, T stddev, std::mt19937& gen) { - RandInit(w.pre_attention_norm_scale, stddev, gen); - RandInit(w.attn_vec_einsum_w, stddev, gen); - RandInit(w.qkv_einsum_w, stddev, gen); - RandInit(w.pre_ffw_norm_scale, stddev, gen); - RandInit(w.gating_einsum_w, stddev, gen); - RandInit(w.linear_w, stddev, gen); -} - -template -void RandInit(Weights& w, T stddev, std::mt19937& gen) { - static constexpr size_t kLayers = TConfig::kLayers; - RandInit(w.embedder_input_embedding, stddev, gen); - RandInit(w.final_norm_scale, stddev, gen); - for (size_t i = 0; i < kLayers; ++i) { - RandInit(*w.GetLayer(i), stddev, gen); - } -} - -// Owns weights and provides access to TConfig. -template -class WeightsWrapper { - public: - WeightsWrapper() - : pool_(0), - data_(AllocateWeights()(pool_)), - weights_(reinterpret_cast*>(data_.get())) {} - - ~WeightsWrapper() { - get().layer_ptrs.~LayerPointers(); - } - - const Weights& get() const { return *weights_; } - Weights& get() { return *weights_; } - void clear() { ZeroInitWeights()(data_, pool_); } - void copy(const WeightsWrapper& other) { - CopyWeights()(get(), other.get()); - } - - private: - hwy::ThreadPool pool_; - ByteStorageT data_; - Weights* weights_; -}; - -} // namespace gcpp - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_WEIGHTS_RAW_H_ diff --git a/gemma/common.h b/gemma/common.h index 2960b7c..aa5bc52 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -146,50 +146,52 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, // Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float), // calls FUNC> where ConfigT is chosen via MODEL enum. -#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \ - switch (MODEL) { \ - case Model::GEMMA_TINY: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA_2B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA_7B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GRIFFIN_2B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA2_2B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA2_9B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA2_27B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::PALIGEMMA_224: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>)\ - ARGS; \ - break; \ - } \ - default: \ - HWY_ABORT("Model type %d unknown.", static_cast(MODEL)); \ +#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \ + switch (MODEL) { \ + case Model::GEMMA_TINY: { \ + using CP = ConfigPair, ConfigGemmaTiny>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::GEMMA_2B: { \ + using CP = ConfigPair, ConfigGemma2B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::GEMMA_7B: { \ + using CP = ConfigPair, ConfigGemma7B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::GRIFFIN_2B: { \ + using CP = ConfigPair, ConfigGriffin2B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::GEMMA2_2B: { \ + using CP = ConfigPair, ConfigGemma2_2B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::GEMMA2_9B: { \ + using CP = ConfigPair, ConfigGemma2_9B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case 
Model::GEMMA2_27B: { \ + using CP = \ + ConfigPair, ConfigGemma2_27B>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + case Model::PALIGEMMA_224: { \ + using CP = ConfigPair, \ + ConfigPaliGemma_224>; \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC) ARGS; \ + break; \ + } \ + default: \ + HWY_ABORT("Model type %d unknown.", static_cast(MODEL)); \ } // Like CallForModelAndWeight, but for SIMD function templates. This is a macro diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index d2d9cd0..9b9a0c4 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -244,11 +244,19 @@ class GemmaAttention { const auto pre_att_rms_out = ConstMat(activations_.pre_att_rms_out.All(), kModelDim); - MatMul( - num_interleaved, pre_att_rms_out, - ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim), - layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env, - MutableMat(activations_.q.All(), kHeads * kQStride)); + const auto w_q1 = + layer_weights_.qkv_einsum_w.data() == nullptr + ? ConstMat(layer_weights_.qkv_einsum_w1.data(), kModelDim) + : ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim); + const auto w_q2 = + layer_weights_.qkv_einsum_w.data() == nullptr + ? ConstMat(layer_weights_.qkv_einsum_w2.data(), kModelDim) + : ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim, + kHeads * kQKVDim * kModelDim); + MatMul(num_interleaved, pre_att_rms_out, w_q1, + layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, + activations_.env, + MutableMat(activations_.q.All(), kHeads * kQStride)); if constexpr (kIsMHA) { static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved"); @@ -263,9 +271,7 @@ class GemmaAttention { // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; MatMul( - num_tokens_, pre_att_rms_out, - ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim, - kHeads * kQKVDim * kModelDim), + num_tokens_, pre_att_rms_out, w_q2, layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env, MutableMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize)); @@ -283,9 +289,14 @@ class GemmaAttention { cache_pos * kCachePosSize + layer_ * kCacheLayerSize; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). - MatVec( - layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, kv, - pool_); + if (layer_weights_.qkv_einsum_w.data() == nullptr) { + MatVec( + layer_weights_.qkv_einsum_w2, 0, x, kv, pool_); + } else { + MatVec( + layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, + kv, pool_); + } } } } @@ -692,10 +703,16 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, output_bias = layer_weights->ffw_output_biases.data_scale1(); } if constexpr (!kIsVit) { - w1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim); - w2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim, kModelDim, - kModelDim * kFFHiddenDim); - scale = layer_weights->gating_einsum_w.scale(); + w1 = layer_weights->gating_einsum_w.data() == nullptr + ? ConstMat(layer_weights->gating_einsum_w1.data(), kModelDim) + : ConstMat(layer_weights->gating_einsum_w.data(), kModelDim); + w2 = layer_weights->gating_einsum_w.data() == nullptr + ? ConstMat(layer_weights->gating_einsum_w2.data(), kModelDim) + : ConstMat(layer_weights->gating_einsum_w.data(), kModelDim, + kModelDim, kModelDim * kFFHiddenDim); + scale = layer_weights->gating_einsum_w.data() == nullptr + ? 
layer_weights->gating_einsum_w1.scale() + : layer_weights->gating_einsum_w.scale(); w_output = ConstMat(layer_weights->linear_w.data(), kFFHiddenDim); output_scale = layer_weights->linear_w.scale(); } else { @@ -712,7 +729,7 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, hidden_activations); if constexpr (!kIsVit) { MatMul(num_interleaved, x, w2, scale, bias2, activations.env, - multiplier); + multiplier); } // Activation (Gelu) and maybe multiply by gate. Store activations in act. diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 94cffc1..722adcb 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -52,8 +52,6 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, } Gemma::~Gemma() { - CallForModelAndWeight(info_.model, info_.weight, - weights_u8_); } // There are >100 instantiations of the inference code. To reduce compile time, diff --git a/gemma/weights.cc b/gemma/weights.cc index 77f0628..955c4d6 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -15,15 +15,15 @@ #include "gemma/weights.h" -#include - +#include #include +#include #include "compression/compress.h" #include "compression/io.h" // Path #include "gemma/common.h" -#include "gemma/configs.h" #include "util/allocator.h" +#include "hwy/aligned_allocator.h" #include "hwy/base.h" // HWY_ABORT #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" @@ -47,32 +47,23 @@ struct LoadCompressedWeightsT { CWeights* c_weights = reinterpret_cast(c_weights_u8.get()); new (c_weights) CWeights(pool); - std::array scales; CacheLoader loader(weights); - ForEachTensor(nullptr, *c_weights, loader); - loader.LoadScales(scales.data(), scales.size()); - if (!loader.ReadAll(pool)) { + ForEachType fet = + loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; + CWeights::ForEachTensor( + {c_weights}, fet, + [&loader](const char* name, hwy::Span tensors) { + loader(name, tensors); + }); + std::vector scales(TConfig::kNumTensorScales); + if (TConfig::kNumTensorScales > 0) { + loader.LoadScales(scales.data(), scales.size()); + } + if (!loader.ReadAll(pool, c_weights->model_storage)) { HWY_ABORT("Failed to load model weights."); } if (TConfig::kNumTensorScales > 0) { - size_t scale_pos = 0; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - CompressedLayer* layer_weights = c_weights->GetLayer(idx); - if (type == LayerAttentionType::kGemma) { - layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]); - layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]); - } else { - layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]); - layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]); - layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]); - layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]); - } - layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]); - layer_weights->linear_w.set_scale(scales[scale_pos++]); - } - HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); + c_weights->GetOrApplyScales(scales); } { PROFILER_ZONE("Startup.Reshape"); @@ -102,13 +93,13 @@ void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) { class WeightLogger { public: - template - void operator()(const char* name, const CompressedArray& tensor) { + void operator()(const char* name, hwy::Span tensors) { + const MatPtr& tensor = *tensors[0]; if (tensor.scale() != 1.0f) { printf("[scale=%f] ", 
tensor.scale()); } - LogVec(name, tensor.data(), N); - total_weights += N; + LogVec(name, tensor.data(), tensor.NumElements()); + total_weights += tensor.NumElements(); } size_t total_weights = 0; }; @@ -116,10 +107,11 @@ class WeightLogger { template struct LogWeightStatsT { void operator()(const ByteStorageT& weights_u8) const { - const auto& weights = + auto& weights = *reinterpret_cast*>(weights_u8.get()); WeightLogger logger; - ForEachTensor1(logger, weights); + CompressedWeights::ForEachTensor( + {&weights}, ForEachType::kIgnoreNulls, logger); printf("%-20s %12zu\n", "Total", logger.total_weights); } }; diff --git a/gemma/weights.h b/gemma/weights.h index b24b0c6..0c97253 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -18,7 +18,15 @@ #include +#include +#include +#include +#include +#include +#include + #include "compression/compress.h" +#include "compression/shared.h" #include "gemma/common.h" #include "gemma/configs.h" #include "util/allocator.h" @@ -28,16 +36,82 @@ namespace gcpp { +// Different tensors need to appear in a ForEachTensor, according to what is +// happening. +enum class ForEachType { + // Under normal circumstances, when not initializing or loading, we can + // include all tensors and ignore the null ones. + kIgnoreNulls, + // If there is a table of contents, we can include all tensors. + kLoadWithToc, + // There is no table of contents, so we have to be careful to only include + // tensors that are actually present. + kLoadNoToc, + // We need to initialize all tensors needed when there is no table of + // contents. This differs from kLoadNoToc in that we need to include any + // tensor that is allocated but not loaded directly from file. + kInitNoToc, +}; + template struct CompressedLayer { - // No ctor/dtor, allocated via AllocateAligned. + // Large data is constructed separately. + CompressedLayer() + : attn_vec_einsum_w("att_ein", kModelDim, kHeads * kQKVDim), + qkv_einsum_w("qkv_ein", (kHeads + 2 * kKVHeads) * kQKVDim, kModelDim), + qkv_einsum_w1("qkv1_w", kHeads * kQKVDim, kModelDim), + qkv_einsum_w2("qkv2_w", 2 * kKVHeads * kQKVDim, kModelDim), + attention_output_biases("attn_ob", 1, kAOBiasDim), + griffin({.linear_x_w = {"gr_lin_x_w", kGriffinDim, kGriffinDim}, + .linear_x_biases = {"gr_lin_x_b", 1, kGriffinDim}, + .linear_y_w = {"gr_lin_y_w", kGriffinDim, kGriffinDim}, + .linear_y_biases = {"gr_lin_y_b", 1, kGriffinDim}, + .linear_out_w = {"gr_lin_out_w", kGriffinDim, kGriffinDim}, + .linear_out_biases = {"gr_lin_out_b", 1, kGriffinDim}, + .conv_w = {"gr_conv_w", kConv1dWidth, kGriffinDim}, + .conv_biases = {"gr_conv_b", 1, kGriffinDim}, + .gate_w = {"gr_gate_w", 2 * kGriffinDim, kGriffinDim / kHeads}, + .gate_biases = {"gr_gate_b", 1, kGriffinDim * 2}, + .a = {"gr_a", 1, kGriffinDim}}), + // MultiHeadDotProductAttention. 
+ vit({.attn_out_w = {"attn_out_w", kHeads * kQKVDim, kModelDim}, + .attn_out_b = {"attn_out_b", 1, kModelDim}, + .qkv_einsum_w = {"qkv_ein_w", (kHeads + 2 * kKVHeads) * kQKVDim, + kModelDim}, + .qkv_einsum_b = {"qkv_ein_b", (kHeads + 2 * kKVHeads), kQKVDim}, + .linear_0_w = {"linear_0_w", kModelDim, kFFHiddenDim}, + .linear_0_b = {"linear_0_b", 1, kFFHiddenDim}, + .linear_1_w = {"linear_1_w", kFFHiddenDim, kModelDim}, + .linear_1_b = {"linear_1_b", 1, kModelDim}, + .layer_norm_0_bias = {"ln_0_bias", 1, kModelDim}, + .layer_norm_0_scale = {"ln_0_scale", 1, kModelDim}, + .layer_norm_1_bias = {"ln_1_bias", 1, kModelDim}, + .layer_norm_1_scale = {"ln_1_scale", 1, kModelDim}}), + gating_einsum_w("gating_ein", 2 * kFFHiddenDim, kModelDim), + gating_einsum_w1("gating1_w", kFFHiddenDim, kModelDim), + gating_einsum_w2("gating2_w", kFFHiddenDim, kModelDim), + linear_w("linear_w", kModelDim, kFFHiddenDim), + pre_attention_norm_scale("pre_att_ns", 1, kModelDim), + pre_ffw_norm_scale("pre_ff_ns", 1, kModelDim), + post_attention_norm_scale( + "post_att_ns", 1, kPostNorm == PostNormType::Scale ? kModelDim : 0), + post_ffw_norm_scale("post_ff_ns", 1, + kPostNorm == PostNormType::Scale ? kModelDim : 0), + ffw_gating_biases("ffw_gat_b", 1, kFFBiases ? 2 * kFFHiddenDim : 0), + ffw_output_biases("ffw_out_b", 1, kFFBiases ? kModelDim : 0), + att_weights("att_w", kModelDim, kHeads * kQKVDim) + {} + ~CompressedLayer() = default; using Weight = typename TConfig::Weight; // If weights are f32, also f32; otherwise at least bf16. Useful for ops that // do not yet support smaller compressed types, or require at least bf16. When // weights are f32, we also want such tensors to be f32. - using WeightF32OrBF16 = - hwy::If(), float, hwy::bfloat16_t>; + // If weights are complex, this is also complex. + using WeightF32OrBF16 = hwy::If< + hwy::IsSame>(), std::complex, + hwy::If(), double, + hwy::If(), float, hwy::bfloat16_t>>>; static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; @@ -58,69 +132,75 @@ struct CompressedLayer { static constexpr size_t kGriffinDim = TConfig::kGriffinLayers > 0 ? kModelDim : 0; - template - using ArrayT = CompressedArray; + template + using ArrayT = MatPtrT; - union { - struct { - ArrayT attn_vec_einsum_w; - ArrayT qkv_einsum_w; - ArrayT attention_output_biases; - }; + ArrayT attn_vec_einsum_w; + // qkv_einsum_w holds 2 different matrices, which may be separated out. + // On loading, which is used depends on what is in the file. + // At inference, the one with a non-null ptr is used. + ArrayT qkv_einsum_w; + ArrayT qkv_einsum_w1; + ArrayT qkv_einsum_w2; + ArrayT attention_output_biases; - struct { - ArrayT linear_x_w; - ArrayT linear_x_biases; - ArrayT linear_y_w; - ArrayT linear_y_biases; - ArrayT linear_out_w; - ArrayT linear_out_biases; - ArrayT conv_w; - ArrayT conv_biases; - ArrayT gate_w; - ArrayT gate_biases; - ArrayT a; - } griffin; + struct { + ArrayT linear_x_w; + ArrayT linear_x_biases; + ArrayT linear_y_w; + ArrayT linear_y_biases; + ArrayT linear_out_w; + ArrayT linear_out_biases; + ArrayT conv_w; + ArrayT conv_biases; + ArrayT gate_w; + ArrayT gate_biases; + ArrayT a; + } griffin; - struct { - // MultiHeadDotProductAttention. - ArrayT attn_out_w; - ArrayT attn_out_b; - ArrayT qkv_einsum_w; - ArrayT qkv_einsum_b; - // MlpBlock. - ArrayT linear_0_w; - ArrayT linear_0_b; - ArrayT linear_1_w; - ArrayT linear_1_b; - // LayerNorm. 
- ArrayT layer_norm_0_bias; - ArrayT layer_norm_0_scale; - ArrayT layer_norm_1_bias; - ArrayT layer_norm_1_scale; - } vit; - }; + struct { + // MultiHeadDotProductAttention. + ArrayT attn_out_w; + ArrayT attn_out_b; + ArrayT qkv_einsum_w; + ArrayT qkv_einsum_b; + // MlpBlock. + ArrayT linear_0_w; + ArrayT linear_0_b; + ArrayT linear_1_w; + ArrayT linear_1_b; + // LayerNorm. + ArrayT layer_norm_0_bias; + ArrayT layer_norm_0_scale; + ArrayT layer_norm_1_bias; + ArrayT layer_norm_1_scale; + } vit; - ArrayT gating_einsum_w; - ArrayT linear_w; + // gating_einsum_w holds 2 different matrices, which may be separated out. + // On loading, which is used depends on what is in the file. + // At inference, the one with a non-null ptr is used. + ArrayT gating_einsum_w; + ArrayT gating_einsum_w1; + ArrayT gating_einsum_w2; + ArrayT linear_w; // We don't yet have an RMSNorm that accepts all Weight. - ArrayT pre_attention_norm_scale; - ArrayT pre_ffw_norm_scale; - ArrayT - post_attention_norm_scale; - ArrayT - post_ffw_norm_scale; + ArrayT pre_attention_norm_scale; + ArrayT pre_ffw_norm_scale; + ArrayT post_attention_norm_scale; + ArrayT post_ffw_norm_scale; - ArrayT ffw_gating_biases; - ArrayT ffw_output_biases; + ArrayT ffw_gating_biases; + ArrayT ffw_output_biases; // Reshaped attention; not loaded from disk via ForEachTensor. - ArrayT att_weights; + ArrayT att_weights; // Initializes att_weights from attn_vec_einsum_w, hence this must be called // after loading weights via ForEachTensor. // TODO: update compression/convert_weights to bake this in. - void Reshape() { + void Reshape(MatStorage& storage) { + if (attn_vec_einsum_w.data() == nullptr) return; + constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kHeads = TConfig::kHeads; constexpr size_t kQKVDim = TConfig::kQKVDim; @@ -129,6 +209,8 @@ struct CompressedLayer { static_assert(!hwy::IsSame()); // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. + storage.Allocate(); + att_weights.SetPtr(storage); for (size_t m = 0; m < kModelDim; ++m) { Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim; for (size_t h = 0; h < kHeads; ++h) { @@ -139,118 +221,291 @@ struct CompressedLayer { } att_weights.set_scale(attn_vec_einsum_w.scale()); } -}; -// Array instead of single large allocation for parallel mem init. Split out -// of CompressedWeights so that only these pointers are initialized, not the -// CompressedArray. -template -struct CompressedLayerPointers { - explicit CompressedLayerPointers(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { - this->c_layers[task] = hwy::AllocateAligned>(1); - }); - if constexpr (TConfig::VitConfig::kLayers > 0) { - pool.Run(0, TConfig::VitConfig::kLayers, - [this](uint64_t task, size_t /*thread*/) { - this->c_vit_layers[task] = hwy::AllocateAligned< - CompressedLayer>(1); - }); +// Used by ForEachTensor for per-layer tensors. 
+#define GEMMA_CALL_FUNC(member) \ + { \ + for (int i = 0; i < ptrs.size(); ++i) { \ + tensors[i] = &ptrs[i]->member; \ + } \ + if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \ + func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \ + hwy::Span(tensors, ptrs.size())); \ + } \ + } + + template + static void ForEachTensor(const std::vector*>& ptrs, + int layer_idx, ForEachType fet, Func func, + char sep = ' ', int sep_index = -1) { + MatPtr* tensors[ptrs.size()]; + auto type = TConfig::kLayerConfig[layer_idx]; + if (type == LayerAttentionType::kVit) { + // MHA. + GEMMA_CALL_FUNC(vit.attn_out_w); + GEMMA_CALL_FUNC(vit.attn_out_b); + GEMMA_CALL_FUNC(vit.qkv_einsum_w); + GEMMA_CALL_FUNC(vit.qkv_einsum_b); + // MlpBlock. + GEMMA_CALL_FUNC(vit.linear_0_w); + GEMMA_CALL_FUNC(vit.linear_0_b); + GEMMA_CALL_FUNC(vit.linear_1_w); + GEMMA_CALL_FUNC(vit.linear_1_b); + // LayerNorm. + GEMMA_CALL_FUNC(vit.layer_norm_0_bias); + GEMMA_CALL_FUNC(vit.layer_norm_0_scale); + GEMMA_CALL_FUNC(vit.layer_norm_1_bias); + GEMMA_CALL_FUNC(vit.layer_norm_1_scale); + return; + } + if (type == LayerAttentionType::kGemma) { + if (fet != ForEachType::kLoadNoToc) { + GEMMA_CALL_FUNC(att_weights); + } + if (fet == ForEachType::kInitNoToc || fet == ForEachType::kLoadNoToc || + fet == ForEachType::kIgnoreNulls) { + GEMMA_CALL_FUNC(attn_vec_einsum_w); + } + GEMMA_CALL_FUNC(qkv_einsum_w); + if (fet == ForEachType::kIgnoreNulls || + fet == ForEachType::kLoadWithToc) { + // The unwanted ones will be null or not in the toc. + GEMMA_CALL_FUNC(qkv_einsum_w1); + GEMMA_CALL_FUNC(qkv_einsum_w2); + } + } else { + GEMMA_CALL_FUNC(griffin.linear_x_w); + GEMMA_CALL_FUNC(griffin.linear_x_biases); + GEMMA_CALL_FUNC(griffin.linear_y_w); + GEMMA_CALL_FUNC(griffin.linear_y_biases); + GEMMA_CALL_FUNC(griffin.linear_out_w); + GEMMA_CALL_FUNC(griffin.linear_out_biases); + GEMMA_CALL_FUNC(griffin.conv_w); + GEMMA_CALL_FUNC(griffin.conv_biases); + GEMMA_CALL_FUNC(griffin.gate_w); + GEMMA_CALL_FUNC(griffin.gate_biases); + GEMMA_CALL_FUNC(griffin.a); + } + GEMMA_CALL_FUNC(gating_einsum_w); + if (fet == ForEachType::kIgnoreNulls || fet == ForEachType::kLoadWithToc) { + // The unwanted ones will be null or not in the toc. + GEMMA_CALL_FUNC(gating_einsum_w1); + GEMMA_CALL_FUNC(gating_einsum_w2); + } + GEMMA_CALL_FUNC(linear_w); + GEMMA_CALL_FUNC(pre_attention_norm_scale); + GEMMA_CALL_FUNC(pre_ffw_norm_scale); + + if (TConfig::kPostNorm == PostNormType::Scale) { + GEMMA_CALL_FUNC(post_attention_norm_scale); + GEMMA_CALL_FUNC(post_ffw_norm_scale); + } + + if (TConfig::kFFBiases) { + GEMMA_CALL_FUNC(ffw_gating_biases); + GEMMA_CALL_FUNC(ffw_output_biases); + } + + if (TConfig::kSoftmaxAttnOutputBiases && + type == LayerAttentionType::kGemma) { + GEMMA_CALL_FUNC(attention_output_biases); } } - using CLayer = CompressedLayer; - std::array, TConfig::kLayers> c_layers; - using CVitLayer = CompressedLayer; - std::array, - TConfig::VitConfig::kLayers> - c_vit_layers; + // Sets all the tensors in the layer to zero. Memory must have been allocated. + void ZeroInit(int layer_idx) { + ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls, + [](const char*, hwy::Span tensors) { + tensors[0]->ZeroInit(); + }); + } + + // Allocates memory for all the tensors in the layer. + // Note that this is slow and only used for a stand-alone layer. 
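A minimal sketch (not part of the patch) of a cross-cutting pass written against ForEachTensor: a lambda over type-erased MatPtr* is enough, as in ZeroInit() above and the WeightLogger in gemma/weights.cc earlier in this patch. The helper name is made up; TConfig is any model config.

template <class TConfig>
size_t LayerBytes() {
  CompressedLayer<TConfig> layer;
  layer.Allocate();  // stand-alone layer; see Allocate() just below
  size_t bytes = 0;
  CompressedLayer<TConfig>::ForEachTensor(
      {&layer}, /*layer_idx=*/0, ForEachType::kIgnoreNulls,
      [&bytes](const char* /*name*/, hwy::Span<MatPtr*> tensors) {
        bytes += tensors[0]->SizeBytes();
      });
  return bytes;
}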
+ void Allocate() { + layer_storage.clear(); + ForEachTensor({this}, /*layer_idx=*/0, ForEachType::kInitNoToc, + [this](const char* name, hwy::Span tensors) { + this->layer_storage.emplace_back(*tensors[0]); + layer_storage.back().Allocate(); + tensors[0]->SetPtr(layer_storage.back()); + }); + } + + // Storage for all the matrices and vectors. Only used for a stand-alone + // layer. For a model, the CompressedWeights::model_storage is used instead. + std::vector layer_storage; }; template struct CompressedWeights { - // Must be allocated via AllocateAligned and initialized with placement new. - void* operator new(size_t, void* addr) { return addr; } - void* operator new(size_t) = delete; - void* operator new[](size_t) = delete; - void operator delete(void*) = delete; - void operator delete[](void*) = delete; + explicit CompressedWeights(hwy::ThreadPool& pool) + : embedder_input_embedding("c_embedding", TConfig::kVocabSize, + TConfig::kModelDim), + final_norm_scale("c_final_norm", 1, TConfig::kModelDim), + vit_encoder_norm_bias("c_vit_encoder_norm_bias", 1, + TConfig::VitConfig::kModelDim), + vit_encoder_norm_scale("c_vit_encoder_norm_scale", 1, + TConfig::VitConfig::kModelDim), + vit_img_embedding_bias("c_vit_img_embedding_bias", 1, + TConfig::VitConfig::kModelDim), + vit_img_embedding_kernel("c_vit_img_embedding_kernel", 14 * 14 * 3, + TConfig::VitConfig::kModelDim), + vit_img_pos_embedding("c_vit_img_pos_embedding", 256, + TConfig::VitConfig::kModelDim), + vit_img_head_bias("c_vit_img_head_bias", 1, TConfig::kModelDim), + vit_img_head_kernel("c_vit_img_head_kernel", + TConfig::VitConfig::kModelDim, TConfig::kModelDim), + scale_names({"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", + "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}) {} + + ~CompressedWeights() = default; using Weight = typename TConfig::Weight; - + using WeightF32OrBF16 = typename CompressedLayer::WeightF32OrBF16; using WeightF32OrInputT = - hwy::If(), float, EmbedderInputT>; - CompressedArray - embedder_input_embedding; + hwy::If(), EmbedderInputT, + WeightF32OrBF16>; - using WeightF32OrBF16 = - hwy::If(), float, hwy::bfloat16_t>; - CompressedArray final_norm_scale; + MatPtrT embedder_input_embedding; + MatPtrT final_norm_scale; // Vit parts. - CompressedArray - vit_encoder_norm_bias; - CompressedArray - vit_encoder_norm_scale; - CompressedArray vit_img_embedding_bias; - CompressedArray - vit_img_embedding_kernel; - CompressedArray - vit_img_pos_embedding; + MatPtrT vit_encoder_norm_bias; + MatPtrT vit_encoder_norm_scale; + MatPtrT vit_img_embedding_bias; + MatPtrT vit_img_embedding_kernel; + MatPtrT vit_img_pos_embedding; // The head maps from VitConfig::kModelDim (Vit final layer) to // kModelDim (LLM input). - CompressedArray vit_img_head_bias; - CompressedArray - vit_img_head_kernel; + MatPtrT vit_img_head_bias; + MatPtrT vit_img_head_kernel; - // Must be last so that the other arrays remain aligned. - CompressedLayerPointers c_layer_ptrs; + // Storage for all the matrices and vectors. + std::vector model_storage; + std::unordered_set scale_names; - explicit CompressedWeights(hwy::ThreadPool& pool) - : c_layer_ptrs(pool) - {} + CompressedLayer c_layers[TConfig::kLayers]; + CompressedLayer + vit_layers[TConfig::VitConfig::kLayers]; // Called by weights.cc after ForEachTensor. 
void Reshape(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t layer, size_t /*thread*/) { - GetLayer(layer)->Reshape(); - }); + size_t storage_index = model_storage.size(); + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + model_storage.emplace_back(GetLayer(layer)->att_weights); + } + pool.Run(0, TConfig::kLayers, + [this, storage_index](uint64_t layer, size_t /*thread*/) { + GetLayer(layer)->Reshape(model_storage[storage_index + layer]); + }); } void ZeroInit() { - hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding)); - hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale)); - hwy::ZeroBytes(&vit_encoder_norm_bias, sizeof(vit_encoder_norm_bias)); - hwy::ZeroBytes(&vit_encoder_norm_scale, sizeof(vit_encoder_norm_scale)); - hwy::ZeroBytes(&vit_img_embedding_bias, sizeof(vit_img_embedding_bias)); - hwy::ZeroBytes(&vit_img_embedding_kernel, sizeof(vit_img_embedding_kernel)); - hwy::ZeroBytes(&vit_img_head_bias, sizeof(vit_img_head_bias)); - hwy::ZeroBytes(&vit_img_head_kernel, sizeof(vit_img_head_kernel)); - hwy::ZeroBytes(&vit_img_pos_embedding, sizeof(vit_img_pos_embedding)); + embedder_input_embedding.ZeroInit(); + final_norm_scale.ZeroInit(); for (int i = 0; i < TConfig::kLayers; ++i) { - hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i))); - } - if constexpr (TConfig::VitConfig::kLayers > 0) { - for (int i = 0; i < TConfig::VitConfig::kLayers; ++i) { - hwy::ZeroBytes(GetVitLayer(i), sizeof(*GetVitLayer(i))); - } + c_layers[i].ZeroInit(i); } } const CompressedLayer* GetLayer(size_t layer) const { - return c_layer_ptrs.c_layers[layer].get(); - } - CompressedLayer* GetLayer(size_t layer) { - return c_layer_ptrs.c_layers[layer].get(); + return &c_layers[layer]; } + CompressedLayer* GetLayer(size_t layer) { return &c_layers[layer]; } const CompressedLayer* GetVitLayer( size_t layer) const { - return c_layer_ptrs.c_vit_layers[layer].get(); + return &vit_layers[layer]; } CompressedLayer* GetVitLayer(size_t layer) { - return c_layer_ptrs.c_vit_layers[layer].get(); + return &vit_layers[layer]; } + + // Copies the data from other to *this. + void CopyFrom(const CompressedWeights& other) { + ForEachTensor({this, const_cast*>(&other)}, + ForEachType::kIgnoreNulls, + [](const char*, hwy::Span tensors) { + hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(), + tensors[1]->SizeBytes()); + }); + } + + // If scales is empty, computes and returns the scale factors for the tensors, + // otherwise applies the scale factors to the tensors. + void GetOrApplyScales(std::vector& scales) { + int scale_pos = 0; + ForEachTensor( + {this}, ForEachType::kIgnoreNulls, + [&scales, &scale_pos, this](const char*, hwy::Span tensors) { + if (this->scale_names.count(tensors[0]->Name())) { + if (scale_pos < scales.size()) { + tensors[0]->set_scale(scales[scale_pos]); + } else { + float scale = ScaleWeights(tensors[0]->data(), + tensors[0]->NumElements()); + scales.push_back(scale); + } + ++scale_pos; + } + }); + HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); + } + + template + static void ForEachTensor( + const std::vector*>& ptrs, ForEachType fet, + Func func) { + std::vector*> layers(ptrs.size()); + std::vector*> vit_layers( + ptrs.size()); + MatPtr* tensors[ptrs.size()]; + // Variables used by GEMMA_CALL_FUNC. + int layer_idx = -1; + char sep = ' '; + int sep_index = -1; + GEMMA_CALL_FUNC(embedder_input_embedding); + GEMMA_CALL_FUNC(final_norm_scale); + if constexpr (TConfig::VitConfig::kLayers > 0) { + // Vit parts. 
+ GEMMA_CALL_FUNC(vit_encoder_norm_bias); + GEMMA_CALL_FUNC(vit_encoder_norm_scale); + GEMMA_CALL_FUNC(vit_img_embedding_bias); + GEMMA_CALL_FUNC(vit_img_embedding_kernel); + GEMMA_CALL_FUNC(vit_img_pos_embedding); + GEMMA_CALL_FUNC(vit_img_head_bias); + GEMMA_CALL_FUNC(vit_img_head_kernel); + } + + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + for (int i = 0; i < ptrs.size(); ++i) { + layers[i] = ptrs[i]->GetLayer(layer_idx); + } + CompressedLayer::ForEachTensor(layers, layer_idx, fet, func); + } + + // Vit layers. Not supported for compress_weights. + if constexpr (TConfig::VitConfig::kLayers > 0) { + for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers; + ++layer_idx) { + auto type = TConfig::VitConfig::kLayerConfig[layer_idx]; + HWY_ASSERT(type == LayerAttentionType::kVit); + for (int i = 0; i < ptrs.size(); ++i) { + vit_layers[i] = ptrs[i]->GetVitLayer(layer_idx); + } + CompressedLayer::ForEachTensor( + vit_layers, layer_idx, fet, func); + } + } + } +}; +#undef GEMMA_CALL_FUNC + +// Pair of configs for the compressed and uncompressed weights. +template +struct ConfigPair { + using uc = UCConfig; + using c = CConfig; }; // ---------------------------------------------------------------------------- @@ -263,6 +518,20 @@ struct AllocateCompressedWeights { ByteStorageT weights_u8 = AllocateSizeof(); TWeights* weights = reinterpret_cast(weights_u8.get()); new (weights) TWeights(pool); + std::vector model_toc; + auto& model_storage = weights->model_storage; + TWeights::ForEachTensor( + {weights}, ForEachType::kInitNoToc, + [&model_toc, &model_storage](const char*, hwy::Span tensors) { + model_toc.push_back(tensors[0]); + model_storage.emplace_back(*tensors[0]); + }); + // Allocate in parallel using the pool. + pool.Run(0, model_storage.size(), + [&model_toc, &model_storage](uint64_t task, size_t /*thread*/) { + model_storage[task].Allocate(); + model_toc[task]->SetPtr(model_storage[task]); + }); return weights_u8; } }; @@ -287,291 +556,11 @@ struct ReshapeCompressedWeights { // TODO: also add RandInitCompressedWeights -template -struct DeleteCompressedWeights { - void operator()(ByteStorageT& weights_u8) const { - CompressedWeights& weights = - *reinterpret_cast*>(weights_u8.get()); - weights.~CompressedWeights(); - } -}; - ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, Type weight_type, hwy::ThreadPool& pool); void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights); -// ---------------------------------------------------------------------------- -// Iterators - -// We rely on `if constexpr` to ensure raw_weights->member is only compiled -// when valid, i.e., kHaveRaw == true, but the IDE analysis does not understand -// this, hence hide the member access from it. -#if HWY_IDE -#define GEMMA_MEMBER(aggregate, member) nullptr -#else -#define GEMMA_MEMBER(aggregate, member) aggregate->member -#endif - -// Used by ForEachTensor for tensors that are not in a layer. -#define GEMMA_CALL_TOP_FUNC(name, member) \ - { \ - const float* raw_tensor = nullptr; \ - if constexpr (kHaveRaw) { \ - raw_tensor = GEMMA_MEMBER(raw_weights, member.data()); \ - } \ - func(name, raw_tensor, c_weights.member); \ - } - -// Used by ForEachTensor for per-layer tensors. Writes into name_buf. 
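The AllocateCompressedWeights hunk above builds a table of contents of MatPtr* in one serial ForEachTensor pass and only then allocates the backing storage entries in parallel via pool.Run. Below is a rough stand-alone sketch of that two-phase shape; MatPtr and MatStorage are simplified stand-ins, and only hwy::ThreadPool (with the Run(begin, end, func) signature used in the hunk) is a real API.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    #include "hwy/contrib/thread_pool/thread_pool.h"

    struct MatPtr {
      size_t rows = 0, cols = 0;
      float* ptr = nullptr;
      size_t NumElements() const { return rows * cols; }
      void SetPtr(float* p) { ptr = p; }
    };

    // Records the shape up front; the (potentially large) buffer is allocated
    // later so the expensive part can run on the pool.
    struct MatStorage {
      explicit MatStorage(const MatPtr& m) : num(m.NumElements()) {}
      void Allocate() { data.resize(num); }
      size_t num;
      std::vector<float> data;
    };

    void AllocateAll(std::vector<MatPtr>& tensors,
                     std::vector<MatStorage>& storage, hwy::ThreadPool& pool) {
      // Phase 1: serial walk collects metadata ("TOC") and placeholder storage.
      std::vector<MatPtr*> toc;
      toc.reserve(tensors.size());
      storage.reserve(tensors.size());
      for (MatPtr& t : tensors) {
        toc.push_back(&t);
        storage.emplace_back(t);
      }
      // Phase 2: allocate each buffer on a worker thread, then hook it up.
      pool.Run(0, storage.size(), [&](uint64_t task, size_t /*thread*/) {
        storage[task].Allocate();
        toc[task]->SetPtr(storage[task].data.data());
      });
    }

The split matters because each task touches a distinct storage entry, so no synchronization is needed once the serial pass has fixed the container sizes.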
-#define GEMMA_CALL_FUNC(name, member) \ - snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - { \ - const float* raw_tensor = nullptr; \ - if constexpr (kHaveRaw) { \ - raw_tensor = GEMMA_MEMBER(raw_layer, member.data()); \ - } \ - func(name_buf, raw_tensor, c_layer->member); \ - } - -// Calls func(name, float*, CompressedArray&) for each tensor. float* is -// null if raw_weights is nullptr, e.g., when loading weights from BlobStore. -// Otherwise, RawLayer must be specified and we pass a float* pointing to the -// raw float weights for that tensor for use by compress_weights.cc. -// -// This avoids repeating the list of tensors between loading and compressing, -// while also avoiding dependency on raw_weights.h. -// -// This only calls Func for tensors that TConfig requests/specifies, which means -// scale() is uninitialized for the other tensors, so their data_scale1() must -// not be called. (In other words, if the config doesn't specify a tensor, it -// shouldn't be used.) -template -void ForEachTensor(RawWeightsPtr raw_weights, - CompressedWeights& c_weights, Func& func) { - constexpr bool kHaveRaw = !hwy::IsSame(); - - GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding); - GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale); - - if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) { - GEMMA_CALL_TOP_FUNC("enc_norm_bias", vit_encoder_norm_bias); - GEMMA_CALL_TOP_FUNC("enc_norm_scale", vit_encoder_norm_scale); - GEMMA_CALL_TOP_FUNC("img_emb_bias", vit_img_embedding_bias); - GEMMA_CALL_TOP_FUNC("img_emb_kernel", vit_img_embedding_kernel); - GEMMA_CALL_TOP_FUNC("img_head_bias", vit_img_head_bias); - GEMMA_CALL_TOP_FUNC("img_head_kernel", vit_img_head_kernel); - GEMMA_CALL_TOP_FUNC("img_pos_emb", vit_img_pos_embedding); - } - - char name_buf[16]; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - const RawLayer* raw_layer = nullptr; - if constexpr (kHaveRaw) { - raw_layer = raw_weights->GetLayer(idx); - } - CompressedLayer* c_layer = c_weights.GetLayer(idx); - - GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale); - GEMMA_CALL_FUNC("gating_ein", gating_einsum_w); - GEMMA_CALL_FUNC("linear_w", linear_w); - if (type == LayerAttentionType::kGemma) { - GEMMA_CALL_FUNC("qkv_ein", qkv_einsum_w); - GEMMA_CALL_FUNC("att_ein", attn_vec_einsum_w); - } else { - GEMMA_CALL_FUNC("gr_lin_x_w", griffin.linear_x_w); - GEMMA_CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases); - GEMMA_CALL_FUNC("gr_lin_y_w", griffin.linear_y_w); - GEMMA_CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases); - GEMMA_CALL_FUNC("gr_lin_out_w", griffin.linear_out_w); - GEMMA_CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases); - GEMMA_CALL_FUNC("gr_conv_w", griffin.conv_w); - GEMMA_CALL_FUNC("gr_conv_b", griffin.conv_biases); - GEMMA_CALL_FUNC("gr_gate_w", griffin.gate_w); - GEMMA_CALL_FUNC("gr_gate_b", griffin.gate_biases); - GEMMA_CALL_FUNC("gr_a", griffin.a); - } - GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale); - - if (TConfig::kPostNorm == PostNormType::Scale) { - GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale); - GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale); - } - - if (TConfig::kFFBiases) { - GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases); - GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases); - } - - if (TConfig::kSoftmaxAttnOutputBiases && - type == LayerAttentionType::kGemma) { - GEMMA_CALL_FUNC("attn_ob", attention_output_biases); - } - } - - // Vit layers. 
Not supported for compress_weights. - if constexpr (TConfig::VitConfig::kLayers > 0 && !kHaveRaw) { - for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers; - ++layer_idx) { - auto type = TConfig::VitConfig::kLayerConfig[layer_idx]; - HWY_ASSERT(type == LayerAttentionType::kVit); - const size_t idx = static_cast(layer_idx); - const RawLayer* raw_layer = nullptr; - CompressedLayer* c_layer = - c_weights.GetVitLayer(idx); - - // MHA. - GEMMA_CALL_FUNC("attn_out_w", vit.attn_out_w); - GEMMA_CALL_FUNC("attn_out_b", vit.attn_out_b); - GEMMA_CALL_FUNC("qkv_ein_w", vit.qkv_einsum_w); - GEMMA_CALL_FUNC("qkv_ein_b", vit.qkv_einsum_b); - // MlpBlock. - GEMMA_CALL_FUNC("linear_0_w", vit.linear_0_w); - GEMMA_CALL_FUNC("linear_0_b", vit.linear_0_b); - GEMMA_CALL_FUNC("linear_1_w", vit.linear_1_w); - GEMMA_CALL_FUNC("linear_1_b", vit.linear_1_b); - // LayerNorm. - GEMMA_CALL_FUNC("ln_0_bias", vit.layer_norm_0_bias); - GEMMA_CALL_FUNC("ln_0_scale", vit.layer_norm_0_scale); - GEMMA_CALL_FUNC("ln_1_bias", vit.layer_norm_1_bias); - GEMMA_CALL_FUNC("ln_1_scale", vit.layer_norm_1_scale); - } - } -#undef GEMMA_CALL_FUNC -#undef GEMMA_CALL_TOP_FUNC -} // ForEachTensor - -#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member) -#define GEMMA_CALL_TOP_FUNC2(name, member) \ - func(name, weights1.member, weights2.member) -#define GEMMA_CALL_TOP_FUNC3(name, member) \ - func(name, weights1.member, weights2.member, weights3.member) -#define GEMMA_CALL_TOP_FUNC4(name, member) \ - func(name, weights1.member, weights2.member, \ - weights3.member, weights4.member) - -#define GEMMA_CALL_LAYER_FUNC1(name, member) \ - snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer1.member) - -#define GEMMA_CALL_LAYER_FUNC2(name, member) \ - snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer1.member, layer2.member) - -#define GEMMA_CALL_LAYER_FUNC3(name, member) \ - snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer1.member, layer2.member, layer3.member) - -#define GEMMA_CALL_LAYER_FUNC4(name, member) \ - snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member) - -#define GEMMA_CALL_ALL_LAYER_FUNC(N) \ - if (type == LayerAttentionType::kGemma) { \ - GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \ - GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \ - } else { \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \ - } \ - GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \ - GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \ - GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \ - if (TConfig::kPostNorm == PostNormType::Scale) { \ - GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \ - GEMMA_CALL_LAYER_FUNC ## 
N("post_ff_ns", post_ffw_norm_scale); \ - } \ - GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \ - if (TConfig::kFFBiases) { \ - GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \ - GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \ - } \ - if (TConfig::kSoftmaxAttnOutputBiases && \ - type == LayerAttentionType::kGemma) { \ - GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ - } - -template -void ForEachTensor1(Func& func, const CompressedWeights& weights1) { - GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); - GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); - char name_buf[16]; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - const CompressedLayer& layer1 = *weights1.GetLayer(idx); - GEMMA_CALL_ALL_LAYER_FUNC(1) - } -} - -template -void ForEachTensor1(Func& func, CompressedWeights& weights1) { - GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); - GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); - char name_buf[16]; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - CompressedLayer& layer1 = *weights1.GetLayer(idx); - GEMMA_CALL_ALL_LAYER_FUNC(1) - } -} - -template -void ForEachTensor2(Func& func, const CompressedWeights& weights1, - CompressedWeights& weights2) { - GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding); - GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale); - char name_buf[16]; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - const CompressedLayer& layer1 = *weights1.GetLayer(idx); - CompressedLayer& layer2 = *weights2.GetLayer(idx); - GEMMA_CALL_ALL_LAYER_FUNC(2) - } -} - -template -void ForEachTensor4(Func& func, const CompressedWeights& weights1, - CompressedWeights& weights2, - CompressedWeights& weights3, - CompressedWeights& weights4) { - GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding); - GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale); - char name_buf[16]; - for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { - auto type = TConfig::kLayerConfig[layer_idx]; - const size_t idx = static_cast(layer_idx); - const CompressedLayer& layer1 = *weights1.GetLayer(idx); - CompressedLayer& layer2 = *weights2.GetLayer(idx); - CompressedLayer& layer3 = *weights3.GetLayer(idx); - CompressedLayer& layer4 = *weights4.GetLayer(idx); - GEMMA_CALL_ALL_LAYER_FUNC(4) - } -} - -#undef GEMMA_CALL_TOP_FUNC1 -#undef GEMMA_CALL_TOP_FUNC2 -#undef GEMMA_CALL_TOP_FUNC3 -#undef GEMMA_CALL_TOP_FUNC4 -#undef GEMMA_CALL_LAYER_FUNC1 -#undef GEMMA_CALL_LAYER_FUNC2 -#undef GEMMA_CALL_LAYER_FUNC3 -#undef GEMMA_CALL_LAYER_FUNC4 -#undef GEMMA_CALL_ALL_LAYER_FUNC - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ diff --git a/ops/dot-inl.h b/ops/dot-inl.h index 82935f7..012a956 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -377,20 +377,23 @@ HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) { } // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. 
-template -HWY_INLINE float Dot(const std::array& w, size_t w_ofs, - const VT* vec, size_t num) { +template +HWY_INLINE float Dot(const CompressedArray& w, size_t w_ofs, + const VT* vec_aligned, size_t num) { const hn::ScalableTag d; - return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num); + return w.scale() * + Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num); } // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. -template -HWY_INLINE float Dot(const CompressedArray& w, size_t w_ofs, - const VT* vec, size_t num) { +template +HWY_INLINE float Dot(const MatPtrT& w, size_t w_ofs, + const VT* vec_aligned, size_t num) { const hn::ScalableTag d; - return w.scale() * - Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num); + return w.scale() * Dot(d, + MakeConstSpan(reinterpret_cast(w.Ptr()), + w.NumElements()), + w_ofs, vec_aligned, num); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc index daac9b0..bfc515c 100644 --- a/ops/gemma_matvec_test.cc +++ b/ops/gemma_matvec_test.cc @@ -46,18 +46,18 @@ namespace HWY_NAMESPACE { using FloatPtr = hwy::AlignedFreeUniquePtr; -template -FloatPtr SimpleMatVecAdd(const CompressedArray& mat, - const FloatPtr& vec, const FloatPtr& add) { - FloatPtr raw_mat = hwy::AllocateAligned(kNum); - FloatPtr out = hwy::AllocateAligned(kOuter); +FloatPtr SimpleMatVecAdd(const MatStorageT& mat, const FloatPtr& vec, + const FloatPtr& add) { + FloatPtr raw_mat = hwy::AllocateAligned(mat.NumElements()); + FloatPtr out = hwy::AllocateAligned(mat.Rows()); HWY_ASSERT(raw_mat && out); const hn::ScalableTag df; - DecompressAndZeroPad(df, MakeSpan(mat.data(), kNum), 0, raw_mat.get(), kNum); - for (size_t idx_row = 0; idx_row < kOuter; idx_row++) { + DecompressAndZeroPad(df, MakeSpan(mat.data(), mat.NumElements()), 0, + raw_mat.get(), mat.NumElements()); + for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) { out[idx_row] = 0.0f; - for (size_t idx_col = 0; idx_col < kInner; idx_col++) { - out[idx_row] += raw_mat[kInner * idx_row + idx_col] * vec[idx_col]; + for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) { + out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col]; } out[idx_row] *= mat.scale(); out[idx_row] += add[idx_row]; @@ -65,13 +65,12 @@ FloatPtr SimpleMatVecAdd(const CompressedArray& mat, return out; } -template >> -MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { +template +std::unique_ptr> GenerateMat(size_t offset, + hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - MatPtr mat = std::make_unique>(); - FloatPtr raw_mat = hwy::AllocateAligned(kNum); + auto mat = std::make_unique>("TestMat", kOuter, kInner); + FloatPtr raw_mat = hwy::AllocateAligned(mat->NumElements()); HWY_ASSERT(raw_mat); const float scale = 1.0f / kInner; pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { @@ -81,7 +80,7 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { } }); - CompressScaled(raw_mat.get(), kNum, ws, *mat, pool); + CompressScaled(raw_mat.get(), mat->NumElements(), ws, *mat, pool); mat->set_scale(1.9f); // Arbitrary value, different from 1. 
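The dot-inl.h adapters above fold the per-tensor scale into the result: weights are stored with their scale factored out (see GetOrApplyScales earlier in this patch), so the adapter multiplies the raw dot product by scale() once rather than rescaling every element. A tiny sketch of that contract, with ScaledVec as a stand-in for MatPtrT/CompressedArray:

    #include <cstddef>
    #include <vector>

    // Weights stored pre-divided by `scale`.
    struct ScaledVec {
      std::vector<float> stored;
      float scale = 1.0f;
    };

    // Equivalent to scale * dot(original_weights[w_ofs..], vec[0..num)).
    float ScaledDot(const ScaledVec& w, size_t w_ofs, const float* vec,
                    size_t num) {
      float sum = 0.0f;
      for (size_t i = 0; i < num; ++i) {
        sum += w.stored[w_ofs + i] * vec[i];
      }
      return w.scale * sum;  // undo the storage scaling once per call
    }

(The diff continues below with the gemma_matvec_test.cc changes that exercise these adapters.)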
return mat; } @@ -113,7 +112,7 @@ void TestMatVecAdd() { auto mat = GenerateMat(0, pool); FloatPtr vec = GenerateVec(0); FloatPtr add = GenerateVec(0); - FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add); + FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add); FloatPtr actual_out = hwy::AllocateAligned(kOuter); HWY_ASSERT(vec && add && expected_out && actual_out); MatVecAdd(*mat, 0, vec.get(), add.get(), actual_out.get(), @@ -130,8 +129,8 @@ void TestTwoMatVecAdd() { FloatPtr vec = GenerateVec(0); FloatPtr add0 = GenerateVec(0); FloatPtr add1 = GenerateVec(1); - FloatPtr expected_out0 = SimpleMatVecAdd(*mat0, vec, add0); - FloatPtr expected_out1 = SimpleMatVecAdd(*mat1, vec, add1); + FloatPtr expected_out0 = SimpleMatVecAdd(*mat0, vec, add0); + FloatPtr expected_out1 = SimpleMatVecAdd(*mat1, vec, add1); FloatPtr actual_out0 = hwy::AllocateAligned(kOuter); FloatPtr actual_out1 = hwy::AllocateAligned(kOuter); HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && @@ -151,8 +150,8 @@ void TestTwoOfsMatVecAddLoop() { FloatPtr vec = GenerateVec(0); FloatPtr add0 = GenerateVec(0); FloatPtr add1 = GenerateVec(1); - FloatPtr expected_out0 = SimpleMatVecAdd(*mat, vec, add0); - FloatPtr expected_out1 = SimpleMatVecAdd(*mat, vec, add1); + FloatPtr expected_out0 = SimpleMatVecAdd(*mat, vec, add0); + FloatPtr expected_out1 = SimpleMatVecAdd(*mat, vec, add1); FloatPtr actual_out0 = hwy::AllocateAligned(kOuter); FloatPtr actual_out1 = hwy::AllocateAligned(kOuter); HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && diff --git a/ops/matmul.h b/ops/matmul.h index ecc72b1..4ef63bc 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -28,6 +28,8 @@ namespace gcpp { // Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be // const or non-const. Create via ConstMat/MutableMat. +// TODO(rays): Replace with MatPtr and get rid of stride, which is only != cols +// in one place. template struct Mat { bool NotEmpty() const { diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 2b64f27..b6445b3 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -52,13 +52,13 @@ using FloatPtr = hwy::AlignedFreeUniquePtr; // Generates inputs: deterministic, within max SfpStream range. template >> + class MatPtr = std::unique_ptr>> MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - FloatPtr content = hwy::AllocateAligned(kNum); + auto mat = std::make_unique>("test", kRows, kCols); + FloatPtr content = hwy::AllocateAligned(mat->NumElements()); HWY_ASSERT(content); - const float scale = SfpStream::kMax / (kNum + offset); + const float scale = SfpStream::kMax / (mat->NumElements() + offset); pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) { for (size_t j = 0; j < kCols; j++) { content[i * kCols + j] = @@ -66,19 +66,18 @@ MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { } }); - MatPtr mat = std::make_unique>(); - CompressScaled(content.get(), kNum, ws, *mat, pool); + CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); mat->set_scale(0.6f); // Arbitrary value, different from 1. 
return mat; } template >> + class MatPtr = std::unique_ptr>> MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - FloatPtr content = hwy::AllocateAligned(kNum); - const float scale = SfpStream::kMax / (kNum + offset); + MatPtr mat = std::make_unique>("test", kCols, kRows); + FloatPtr content = hwy::AllocateAligned(mat->NumElements()); + const float scale = SfpStream::kMax / (mat->NumElements() + offset); pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) { for (size_t j = 0; j < kCols; j++) { content[j * kRows + i] = @@ -86,27 +85,25 @@ MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) { } }); - MatPtr mat = std::make_unique>(); - CompressScaled(content.get(), kNum, ws, *mat, pool); + CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); // Arbitrary value, different from 1, must match GenerateMatHeap. mat->set_scale(0.6f); return mat; } template >> + class MatPtr = std::unique_ptr>> MatPtr GenerateZeroMat(hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - FloatPtr content = hwy::AllocateAligned(kNum); + auto mat = std::make_unique>("Array", kRows, kCols); + FloatPtr content = hwy::AllocateAligned(mat->NumElements()); HWY_ASSERT(content); pool.Run(0, kRows, [&](const size_t i, size_t thread) { hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0])); }); - MatPtr mat = std::make_unique>(); - CompressScaled(content.get(), kNum, ws, *mat, pool); + CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); mat->set_scale(1.2f); // Arbitrary value, different from 1. return mat; } @@ -216,21 +213,21 @@ void TestMatMul(MatMulEnv& env) { kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName(), TypeName()); - std::unique_ptr> a = + std::unique_ptr> a = GenerateMat(0, pool); - std::unique_ptr> b_trans = + std::unique_ptr> b_trans = GenerateTransposedMat(0, pool); FloatPtr c = hwy::AllocateAligned(kRowsAC * kColsBC); HWY_ASSERT(c); const float scale = a->scale() * b_trans->scale(); - std::unique_ptr> add; + std::unique_ptr> add; if (kAdd) { add = GenerateMat(0, pool); add->set_scale(1.0f); } - std::unique_ptr> c_slow = + std::unique_ptr> c_slow = GenerateZeroMat(pool); const double start_slow = hwy::platform::Now(); MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale, diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h index 5a56413..5d629ac 100644 --- a/ops/matvec-inl.h +++ b/ops/matvec-inl.h @@ -214,9 +214,9 @@ HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, } // Two matrices, same vector -template -HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1, +template +HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0, @@ -254,10 +254,10 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1, } // With addition -template +template HWY_NOINLINE void TwoMatVecAdd( - const ArrayT& mat0, const ArrayT& mat1, const size_t mat_ofs, + const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs, const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0, const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, hwy::ThreadPool& pool) { @@ -266,13 +266,14 @@ HWY_NOINLINE void TwoMatVecAdd( } // Without addition -template -HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1, +template +HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs, const VecT* 
HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, hwy::ThreadPool& pool) { - TwoMatVecT( + TwoMatVecT( mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr, out0, out1, pool); } diff --git a/paligemma/image.cc b/paligemma/image.cc index 309296c..5dad770 100644 --- a/paligemma/image.cc +++ b/paligemma/image.cc @@ -146,7 +146,7 @@ bool Image::WriteBinary(const std::string& filename) const { std::cerr << "Failed to open " << filename << "\n"; return false; } - for (int i = 0; i < data_.size(); ++i) { + for (size_t i = 0; i < data_.size(); ++i) { file.write(reinterpret_cast<const char*>(&data_[i]), sizeof(float)); } file.close();
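The test changes above replace the compile-time kOuter/kInner/kNum parameters with the matrix object's own Rows()/Cols()/NumElements(), so a single reference helper serves every shape. A sketch of such a runtime-shape reference MatVecAdd, where Mat is a simplified stand-in for MatStorageT<float>:

    #include <cstddef>
    #include <vector>

    struct Mat {
      size_t rows, cols;
      std::vector<float> data;  // row-major, rows * cols elements
      float scale = 1.0f;
      size_t Rows() const { return rows; }
      size_t Cols() const { return cols; }
    };

    // out[r] = scale * dot(row r of mat, vec) + add[r], as in SimpleMatVecAdd.
    std::vector<float> ReferenceMatVecAdd(const Mat& mat,
                                          const std::vector<float>& vec,
                                          const std::vector<float>& add) {
      std::vector<float> out(mat.Rows());
      for (size_t r = 0; r < mat.Rows(); ++r) {
        float sum = 0.0f;
        for (size_t c = 0; c < mat.Cols(); ++c) {
          sum += mat.data[r * mat.Cols() + c] * vec[c];
        }
        out[r] = sum * mat.scale + add[r];
      }
      return out;
    }

Because the dimensions are data rather than template arguments, the same helper can check matrices of any size produced by the updated GenerateMat, without instantiating a new function per shape.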