From 301dc8067aa134f9e40dff1f370de9a1fcd750c2 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 16 Aug 2024 07:51:40 -0700 Subject: [PATCH] Major MatMul update, 1.9-2.3x speedup on Zen4 via bf16 mul Supports converting all weight/activation formats to native MulT (bf16/f32) Also: - ConstMat/MutableMat for const correctness - Move RowVectorBatch to allocator.h so it can be used from Matmul - Add matmul.h so MatMulEnv can be used from Activations - Remove kMaxThreads, detect from PerClusterPools - Build fix: -inl.h files must be textual_hdrs, and highway.h should precede -inl.h ``` zen4 new 64, 24576, 3072, add=0, MatTA=bf16, MatTB=sfp: 616.6 GFLOPS. 64, 3072, 24576, add=0, MatTA=bf16, MatTB=sfp: 460.7 GFLOPS. 64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp: 598.6 GFLOPS. 64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp: 435.6 GFLOPS. zen4 old 64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp: 257.5 GFLOPS. 64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp: 231.9 GFLOPS. ``` PiperOrigin-RevId: 663729812 --- BUILD.bazel | 51 ++- CMakeLists.txt | 2 + backprop/activations.h | 2 +- backprop/backward-inl.h | 4 +- backprop/backward.cc | 4 +- backprop/optimizer.h | 1 + compression/BUILD | 1 + compression/compress-inl.h | 36 +- compression/sfp-inl.h | 11 +- compression/weights_raw.h | 2 +- gemma/activations.h | 56 +-- gemma/common.h | 9 - gemma/configs.h | 6 - gemma/gemma-inl.h | 304 ++++++--------- gemma/weights.h | 1 + ops/matmul-inl.h | 770 +++++++++++++++++++------------------ ops/matmul.h | 97 +++++ ops/matmul_test.cc | 112 +++--- util/allocator.h | 75 ++++ util/threading.h | 5 + 20 files changed, 862 insertions(+), 687 deletions(-) create mode 100644 ops/matmul.h create mode 100644 util/allocator.h diff --git a/BUILD.bazel b/BUILD.bazel index a2b647a..a6a9534 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -20,14 +20,37 @@ licenses(["notice"]) exports_files(["LICENSE"]) +cc_library( + name = "allocator", + hdrs = ["util/allocator.h"], + deps = [ + "@hwy//:hwy", + ], +) + +cc_library( + name = "threading", + hdrs = ["util/threading.h"], + deps = [ + "@hwy//:hwy", + "@hwy//:thread_pool", + "@hwy//:topology", + ], +) + cc_library( name = "ops", + hdrs = [ + "ops/matmul.h", + ], textual_hdrs = [ "ops/ops-inl.h", "ops/matmul-inl.h", "ops/matvec-inl.h", ], deps = [ + ":allocator", + ":threading", "//compression:compress", "//compression:sfp", "@hwy//:algo", @@ -86,6 +109,7 @@ cc_test( tags = ["hwy_ops_test"], deps = [ ":ops", + ":threading", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "@hwy//:hwy", @@ -114,6 +138,7 @@ cc_library( srcs = ["gemma/weights.cc"], hdrs = ["gemma/weights.h"], deps = [ + ":allocator", ":common", "//compression:compress", "//compression:io", @@ -148,16 +173,6 @@ cc_library( ], ) -cc_library( - name = "threading", - hdrs = ["util/threading.h"], - deps = [ - "@hwy//:hwy", - "@hwy//:thread_pool", - "@hwy//:topology", - ], -) - cc_library( name = "gemma_lib", srcs = [ @@ -197,6 +212,7 @@ cc_library( # Placeholder for internal file2, do not remove, ], deps = [ + ":allocator", ":common", ":ops", ":tokenizer", @@ -389,11 +405,14 @@ cc_library( hdrs = [ "backprop/activations.h", "backprop/backward.h", - "backprop/backward-inl.h", "backprop/forward.h", + ], + textual_hdrs = [ + "backprop/backward-inl.h", "backprop/forward-inl.h", ], deps = [ + ":allocator", ":common", ":gemma_lib", ":ops", @@ -413,6 +432,7 @@ cc_library( "backprop/forward_scalar.h", ], deps = [ + ":allocator", ":common", ":gemma_lib", ":prompt", @@ -467,13 +487,10 @@ cc_test( cc_library( name = 
"optimizer", - srcs = [ - "backprop/optimizer.cc", - ], - hdrs = [ - "backprop/optimizer.h", - ], + srcs = ["backprop/optimizer.cc"], + hdrs = ["backprop/optimizer.h"], deps = [ + ":allocator", ":common", ":weights", "//compression:compress", diff --git a/CMakeLists.txt b/CMakeLists.txt index a948783..72a9a9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,8 +103,10 @@ set(SOURCES ops/matmul-inl.h ops/matvec-inl.h ops/ops-inl.h + util/allocator.h util/app.h util/args.h + util/threading.h ) if(NOT CMAKE_BUILD_TYPE) diff --git a/backprop/activations.h b/backprop/activations.h index b3bb455..aee0341 100644 --- a/backprop/activations.h +++ b/backprop/activations.h @@ -20,7 +20,7 @@ #include -#include "gemma/common.h" // ByteStorageT +#include "util/allocator.h" // ByteStorageT namespace gcpp { diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 2b10cc4..41dc21c 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -27,7 +27,6 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/activations.h" // CreateInvTimescale #include "gemma/common.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -42,9 +41,10 @@ #define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #endif +#include "hwy/highway.h" +// After highway.h #include "ops/matmul-inl.h" #include "ops/ops-inl.h" -#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/backprop/backward.cc b/backprop/backward.cc index 27bbd0e..7f06fd4 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -18,6 +18,8 @@ #include "backprop/activations.h" #include "backprop/prompt.h" #include "gemma/common.h" +#include "gemma/weights.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" // Compiles this file for multiple architectures via "foreach_target.h", to @@ -29,8 +31,6 @@ #include "hwy/highway.h" // After highway.h #include "backprop/backward-inl.h" -#include "gemma/activations.h" -#include "gemma/weights.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/backprop/optimizer.h b/backprop/optimizer.h index 9157fa8..b42f311 100644 --- a/backprop/optimizer.h +++ b/backprop/optimizer.h @@ -19,6 +19,7 @@ #include #include "gemma/common.h" +#include "util/allocator.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { diff --git a/compression/BUILD b/compression/BUILD index e5fc9ed..3334da8 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -185,6 +185,7 @@ cc_library( name = "weights_raw", hdrs = ["weights_raw.h"], deps = [ + "//:allocator", "//:common", "//compression:compress", "@hwy//:hwy", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 3da53ce..bfa9f1d 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -89,6 +89,22 @@ struct CompressTraits { f1 = hn::LoadU(df, in + in_ofs + N); } + // Called by MatMul for f32 weights or activations if native + // `ReorderWidenMulAccumulate` is available. 
+ template > + static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in, + size_t in_ofs, VBF16& v0, VBF16& v1) { + const hn::Repartition df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + const VF f0 = hn::LoadU(df, in + in_ofs + 0 * NF); + const VF f1 = hn::LoadU(df, in + in_ofs + 1 * NF); + const VF f2 = hn::LoadU(df, in + in_ofs + 2 * NF); + const VF f3 = hn::LoadU(df, in + in_ofs + 3 * NF); + v0 = hn::OrderedDemote2To(dbf16, f0, f1); + v1 = hn::OrderedDemote2To(dbf16, f2, f3); + } + template static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, const MatT* HWY_RESTRICT in, size_t in_ofs, @@ -196,6 +212,14 @@ struct CompressTraits { f1 = hn::PromoteUpperTo(df, in16); } + template + static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in, + size_t in_ofs, hn::Vec& v0, + hn::Vec& v1) { + v0 = hn::LoadU(dbf16, in + in_ofs); + v1 = hn::LoadU(dbf16, in + in_ofs + hn::Lanes(dbf16)); + } + template static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, const MatT* HWY_RESTRICT in, size_t in_ofs, @@ -318,14 +342,14 @@ struct CompressTraits { } } - template - static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in, - size_t in_ofs, hn::Vec& f0, - hn::Vec& f1) { - const hn::Twice> d8; + template // f32 or bf16 + static HWY_INLINE void Decompress2(D d, const MatT* HWY_RESTRICT in, + size_t in_ofs, hn::Vec& v0, + hn::Vec& v1) { + const hn::Twice> d8; using V8 = hn::Vec; const V8 packed = hn::LoadU(d8, &in->byte + in_ofs); - SfpCodec::Dec2F(df, packed, f0, f1); + SfpCodec::Dec2(d, packed, v0, v1); } template diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 7ca877b..ea5fb4c 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -533,8 +533,8 @@ class SfpCodec { template >>> - static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec& f0, - hn::Vec& f1) { + static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec& f0, + hn::Vec& f1) { const hn::Rebind dbf; using VBF = hn::Vec; VBF bf0, bf1; @@ -543,6 +543,13 @@ class SfpCodec { f1 = hn::PromoteTo(df, bf1); } + template >> + static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec& bf0, + hn::Vec& bf1) { + Dec2B(dbf16, packed, bf0, bf1); + } + private: // Wrappers to avoid code duplication across float/bf16 input types and // the main loop/remainder. diff --git a/compression/weights_raw.h b/compression/weights_raw.h index 0819feb..c6e6a73 100644 --- a/compression/weights_raw.h +++ b/compression/weights_raw.h @@ -26,8 +26,8 @@ #include -#include "gemma/common.h" #include "gemma/configs.h" +#include "util/allocator.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/gemma/activations.h b/gemma/activations.h index 60d5610..f8cb4dd 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -20,51 +20,14 @@ #include -#include "gemma/common.h" // kMaxThreads - TODO: remove -#include "hwy/aligned_allocator.h" +#include "ops/matmul.h" // MatMulEnv +#include "util/allocator.h" // RowVectorBatch +#include "util/threading.h" #include "hwy/base.h" // HWY_DASSERT +#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -// Owns dynamically-allocated aligned memory for a batch of row vectors. -// This can be seen as a (batch_size x len) matrix. -template -class RowVectorBatch { - public: - // Default ctor for Activations ctor. - RowVectorBatch() : batch_size_(0), len_(0) {} - // Main ctor, called from Activations::Allocate. 
- RowVectorBatch(size_t batch_size, size_t len) - : batch_size_(batch_size), len_(len) { - mem_ = hwy::AllocateAligned(batch_size * len); - } - - // Move-only - RowVectorBatch(RowVectorBatch&) noexcept = delete; - RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; - RowVectorBatch(RowVectorBatch&&) noexcept = default; - RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; - - size_t BatchSize() const { return batch_size_; } - size_t Len() const { return len_; } - - // Returns the given row vector of length `Len()`. - T* Batch(size_t batch_idx) { - HWY_DASSERT(batch_idx < batch_size_); - return mem_.get() + batch_idx * len_; - } - - // For MatMul or other operations that process the entire batch at once. - T* All() { return mem_.get(); } - const T* Const() const { return mem_.get(); } - size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); } - - private: - hwy::AlignedFreeUniquePtr mem_; - size_t batch_size_; // rows in the matrix - size_t len_; // columns in the matrix = vector length -}; - struct Activations { RowVectorBatch x; // input RowVectorBatch q; // query, also KV if MHA. @@ -94,9 +57,11 @@ struct Activations { // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into // per-thread storage. - // TODO: remove once MatVec is gone. + // TODO: remove once MatVec is no longer used. RowVectorBatch even_odd; + MatMulEnv env; + // Multi-Head Attention? template static constexpr bool IsMHA() { @@ -126,7 +91,7 @@ struct Activations { } template - void Allocate(size_t batch_size) { + void Allocate(size_t batch_size, PerClusterPools& pools) { constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kQKVDim = TConfig::kQKVDim; constexpr size_t kHeads = TConfig::kHeads; @@ -158,7 +123,10 @@ struct Activations { inv_timescale = CreateInvTimescale(); - even_odd = RowVectorBatch(1, kModelDim * kMaxThreads); + const size_t num_lp = pools.NumLP(); + even_odd = RowVectorBatch(1, kModelDim * num_lp); + + env = MatMulEnv(pools); } }; diff --git a/gemma/common.h b/gemma/common.h index ea541d9..40b8c9f 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -18,24 +18,15 @@ #include // sqrtf #include -#include #include #include "compression/compress.h" #include "gemma/configs.h" // IWYU pragma: export -#include "hwy/aligned_allocator.h" #include "hwy/base.h" // ConvertScalarTo namespace gcpp { -using ByteStorageT = hwy::AlignedFreeUniquePtr; - -template -ByteStorageT AllocateSizeof() { - return hwy::AllocateAligned(sizeof(T)); -} - // Model variants: see configs.h for details. When adding a new one, also // update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc. 
enum class Model { diff --git a/gemma/configs.h b/gemma/configs.h index 3ddcb41..cb903fd 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -36,14 +36,8 @@ namespace gcpp { #define GEMMA_TOPK 1 #endif // !GEMMA_TOPK -// Allow changing upper bound on threads as a compiler flag -#ifndef GEMMA_MAX_THREADS -#define GEMMA_MAX_THREADS 128 -#endif // !GEMMA_MAX_THREADS - static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; static constexpr size_t kTopK = GEMMA_TOPK; -static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS; using EmbedderInputT = hwy::bfloat16_t; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 101adab..86a8193 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -41,6 +41,7 @@ #include "ops/matmul-inl.h" #include "ops/matvec-inl.h" #include "ops/ops-inl.h" +#include "util/allocator.h" #include "util/threading.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" @@ -73,9 +74,10 @@ template HWY_NOINLINE void GriffinRecurrent( size_t batch_start, size_t num_tokens, size_t layer, Activations& activations, const CompressedLayer* layer_weights, - const KVCaches& kv_caches, hwy::ThreadPool& pool) { + const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Griffin"); KVCache& kv_cache = kv_caches[0]; + hwy::ThreadPool& pool = activations.env.Pool(); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -240,12 +242,12 @@ class GemmaAttention { // and kQStride = kQKVDim * (kIsMHA ? 3 : 1); const auto pre_att_rms_out = - MakeMat(activations_.pre_att_rms_out.All(), kModelDim); - MatMul_4x4( + ConstMat(activations_.pre_att_rms_out.All(), kModelDim); + MatMul( num_interleaved, pre_att_rms_out, - MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim), - layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, - MakeMat(activations_.q.All(), kHeads * kQStride), pool_); + ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim), + layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, activations_.env, + MutableMat(activations_.q.All(), kHeads * kQStride)); if constexpr (kIsMHA) { static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved"); @@ -259,12 +261,13 @@ class GemmaAttention { queries_pos_[0] * kCachePosSize + layer_ * kCacheLayerSize; // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; - MatMul_4x4( + MatMul( num_tokens_, pre_att_rms_out, - MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim, - kHeads * kQKVDim * kModelDim), + ConstMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim, + kHeads * kQKVDim * kModelDim), layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, - MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_); + activations_.env, + MutableMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize)); } else { // Proceed row by row because there will be wraparound. for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; @@ -430,19 +433,18 @@ class GemmaAttention { // Thus the [num_interleaved, kModelDim] matmul output is the sum over // heads. 
Compare gemma/modules.py: // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded) - MatMul_4x4( - num_interleaved, MakeMat(activations_.att_out.All(), kHeads * kQKVDim), - MakeMat(layer_weights_.att_weights.data(), kHeads * kQKVDim), - layer_weights_.attn_vec_einsum_w.scale(), bias, - MakeMat(activations_.att_sums.All(), kModelDim), pool_); + MatMul( + num_interleaved, ConstMat(activations_.att_out.All(), kHeads * kQKVDim), + ConstMat(layer_weights_.att_weights.data(), kHeads * kQKVDim), + layer_weights_.attn_vec_einsum_w.scale(), bias, activations_.env, + MutableMat(activations_.att_sums.All(), kModelDim)); } public: GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer, Activations& activations, const CompressedLayer* layer_weights, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - hwy::ThreadPool& pool) + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) : queries_pos_(queries_pos), num_queries_(queries_pos.size()), num_tokens_(num_tokens), @@ -451,7 +453,7 @@ class GemmaAttention { layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), - pool_(pool) { + pool_(activations.env.Pool()) { HWY_DASSERT(num_queries_ <= kv_caches_.size()); } @@ -480,17 +482,17 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t layer, Activations& activations, const CompressedLayer* layer_weights, const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches, hwy::ThreadPool& pool) { + const KVCaches& kv_caches) { if (type == LayerAttentionType::kGemma) { GemmaAttention(queries_pos, num_tokens, layer, activations, - layer_weights, div_seq_len, kv_caches, pool)(); + layer_weights, div_seq_len, kv_caches)(); } else { // Only reached if the model is Griffin. `if constexpr` prevents generating // this code for non-Griffin models. if constexpr (TConfig::kGriffinLayers > 0) { HWY_ASSERT(queries_pos.size() == 1); GriffinRecurrent(queries_pos[0], num_tokens, layer, activations, - layer_weights, kv_caches, pool); + layer_weights, kv_caches); } } } @@ -510,8 +512,7 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2, template HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, - const CompressedLayer* layer_weights, - hwy::ThreadPool& pool) { + const CompressedLayer* layer_weights) { PROFILER_ZONE("Gen.FFW"); constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; @@ -519,10 +520,10 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, // MatMul expects col-major B, which is what we have: kModelDim consecutive // elements in memory, repeated kFFHiddenDim times. 
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); - const auto A = MakeMat(activations.bf_pre_ffw_rms_out.All(), kModelDim); - const auto B1 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim); - const auto B2 = MakeMat(layer_weights->gating_einsum_w.data(), kModelDim, - kModelDim, kModelDim * kFFHiddenDim); + const auto A = ConstMat(activations.bf_pre_ffw_rms_out.All(), kModelDim); + const auto B1 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim); + const auto B2 = ConstMat(layer_weights->gating_einsum_w.data(), kModelDim, + kModelDim, kModelDim * kFFHiddenDim); const float scale = layer_weights->gating_einsum_w.scale(); constexpr bool kAddBias = TConfig::kFFBiases; const float* bias1 = nullptr; @@ -533,22 +534,23 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved, bias2 = bias1 + kFFHiddenDim; output_bias = layer_weights->ffw_output_biases.data_scale1(); } - auto C1 = MakeMat(activations.C1.All(), kFFHiddenDim); - auto C2 = MakeMat(activations.C2.All(), kFFHiddenDim); + auto C1 = MutableMat(activations.C1.All(), kFFHiddenDim); + auto C2 = MutableMat(activations.C2.All(), kFFHiddenDim); // Will go through GELU. - MatMul_4x4(num_interleaved, A, B1, scale, bias1, C1, pool); + MatMul(num_interleaved, A, B1, scale, bias1, activations.env, C1); // What to multiply by. - MatMul_4x4(num_interleaved, A, B2, scale, bias2, C2, pool); + MatMul(num_interleaved, A, B2, scale, bias2, activations.env, C2); // Activation (Gelu) and multiply by gate. Store activations in C1. Activation(C1.ptr, C2.ptr, kFFHiddenDim * num_interleaved); // Hidden layer -> output layer. - MatMul_4x4(num_interleaved, C1, - MakeMat(layer_weights->linear_w.data(), kFFHiddenDim), - layer_weights->linear_w.scale(), output_bias, - MakeMat(activations.ffw_out.All(), kModelDim), pool); + MatMul(num_interleaved, ConstMat(C1), + ConstMat(layer_weights->linear_w.data(), kFFHiddenDim), + layer_weights->linear_w.scale(), output_bias, + activations.env, + MutableMat(activations.ffw_out.All(), kModelDim)); } // `batch_idx` indicates which row of `x` to write to. @@ -594,8 +596,7 @@ template HWY_NOINLINE void TransformerLayer( const QueriesPos& queries_pos, size_t num_tokens, size_t layer, const CompressedLayer* layer_weights, Activations& activations, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - hwy::ThreadPool& pool) { + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_interleaved = num_tokens * queries_pos.size(); auto type = TConfig::kLayerConfig[layer]; @@ -607,7 +608,7 @@ HWY_NOINLINE void TransformerLayer( activations.pre_att_rms_out.All(), kModelDim); Attention(type, queries_pos, num_tokens, layer_of_type, activations, - layer_weights, div_seq_len, kv_caches, pool); + layer_weights, div_seq_len, kv_caches); PostNorm(num_interleaved, layer_weights->post_attention_norm_scale, activations.att_sums.All()); @@ -620,7 +621,7 @@ HWY_NOINLINE void TransformerLayer( layer_weights->pre_ffw_norm_scale.data_scale1(), activations.bf_pre_ffw_rms_out.All(), kModelDim); - FFW(activations, num_interleaved, layer_weights, pool); + FFW(activations, num_interleaved, layer_weights); PostNorm(num_interleaved, layer_weights->post_ffw_norm_scale, activations.ffw_out.All()); @@ -630,120 +631,71 @@ HWY_NOINLINE void TransformerLayer( /*is_attention=*/false); } -// Prefill and Transformer() advance positions in-place. +// Prefill() and Transformer() increment positions in-place. 
using QueriesMutablePos = hwy::Span; -// Batches are important for amortizing loading weights over multiple tokens. -// This is possible in prefill because we know all tokens beforehand, whereas -// decode depends on the previous output token. However, each prefill batch of a -// query requires that preceding batches already wrote to the KV cache, hence we -// sequentially loop over token batches. We can reduce the number of iterations -// by increasing the batch size, but this also increases arithmetic intensity, -// and so we are eventually compute-limited. The tensor parallelism (number of -// threads collaborating on MatMul) is also limited by the CPU topology: -// fork/join barriers are slow(er) when some threads reside in a different NUMA -// node. To allow more threads to help, we also support parallelizing over -// queries in case GenerateBatch was called. -// -// Thus we have two-level parallelism: -// - Outer: handles one 'qbatch' of entire queries. The set of outer workers -// includes the main thread because it is the one that calls `Prefill`, and is -// determined by the number of 'clusters' (shared L3 caches or sockets). -// - Inner: each `outer` worker passes `inner_pools_[outer]` to -// `TransformerLayer` for tensor-level parallelism, and processes -// `tbatch_size` tokens from a single query at a time. -// -// This class holds the thread pools and one activation per outer worker. It is -// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt -// to their num_queries. -class PrefillState { - public: - // `tbatch_size` is the number of tokens from one query to prefill at a time. - template - void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) { - PROFILER_ZONE("Init.Prefill"); - HWY_ASSERT(num_queries != 0); - HWY_ASSERT(activations_.empty()); // only call once. +// Populates KV cache for batches of tokens from one query at a time. +template +HWY_NOINLINE void Prefill( + const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query, + const QueriesMutablePos& queries_pos, const size_t query_idx_start, + const CompressedWeights& weights, Activations& activations, + const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, + const KVCaches& kv_caches) { + PROFILER_ZONE("Gen.Prefill"); + const size_t num_queries = queries_prompt.size(); + HWY_ASSERT(queries_pos.size() == num_queries); + HWY_ASSERT(kv_caches.size() == num_queries); - // Allocate one activation per query, not outer worker, because the common - // case is a single query. If we allocate the lesser of the two, it is - // unclear how to choose an unused activation in Prefill. - activations_.resize(num_queries); + // Batches are important for amortizing loading weights over multiple tokens. + // This is possible in prefill because we know all tokens beforehand, whereas + // decode depends on the previous output token. However, each prefill batch of + // a query requires that preceding batches already wrote to the KV cache, + // hence we sequentially loop over token batches. We can reduce the number of + // iterations by increasing the batch size, but this also increases arithmetic + // intensity, and so we are eventually compute-limited. We could devote some + // threads to parallelizing over queries, but for simplicity we assign them + // all to MatMul. + const size_t max_tbatch_size = activations.x.BatchSize(); - if (num_queries == 1) { - activations_[0].Allocate(tbatch_size); - } else { - // Allocating in parallel can save 30 ms. 
We might have more workers than - // queries/tasks, so do not check the `thread` argument. - pools.Outer().Run(0, num_queries, - [this, tbatch_size](uint64_t qi, size_t /*thread*/) { - activations_[qi].Allocate(tbatch_size); - }); - } + // For each query. `qi` is within the batch, not the global query index. + for (size_t qi = 0; qi < num_queries; ++qi) { + // Single query at a time, so pass slices of the spans because + // GemmaAttention will only access the first KV cache and position. + QueriesPos single_query_pos(&queries_pos[qi], 1); + KVCaches single_kv_cache(&kv_caches[qi], 1); + + // For each batch of tokens in the query: + for (size_t tbatch_start = 0; tbatch_start < prefill_per_query; + tbatch_start += max_tbatch_size) { + // Fill activations.x (much faster than TransformerLayer). + const size_t tbatch_size = + HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start); + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const int token = queries_prompt[qi][tbatch_start + ti]; + const size_t pos = queries_pos[qi] + ti; + EmbedToken(token, ti, pos, weights, activations.x); + } + + // Transformer with one batch of tokens from a single query. + for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { + const auto* layer_weights = weights.GetLayer(layer); + TransformerLayer(single_query_pos, tbatch_size, layer, + layer_weights, activations, div_seq_len, + single_kv_cache); + } + + // NOTE: we unconditionally call StreamToken, even if EOS. + for (size_t ti = 0; ti < tbatch_size; ++ti) { + const size_t pos = queries_pos[qi] + ti; + const int token = queries_prompt[qi][pos]; + runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f); + } + + queries_pos[qi] += tbatch_size; + } // for tbatch_start } - - template - HWY_NOINLINE void Prefill(const QueriesPromptTokens& queries_prompt, - const size_t prefill_per_query, - const QueriesMutablePos& queries_pos, - const size_t query_idx_start, - const CompressedWeights& weights, - const RuntimeConfig& runtime_config, - const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches, PerClusterPools& pools) { - PROFILER_ZONE("Gen.Prefill"); - const size_t num_queries = queries_prompt.size(); - HWY_ASSERT(kv_caches.size() == num_queries); - const size_t max_tbatch_size = activations_[0].x.BatchSize(); - - // For each query (parallel): an outer worker processes all its tokens. - // `qi` is relative to the batch, not the global query index. - pools.Outer().Run( - 0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR { - Activations& activations = activations_[qi]; - hwy::ThreadPool& inner_pool = pools.Inner(qthread); - - // Single query at a time, so pass slices of the spans because - // GemmaAttention will only access the first KV cache and position. - KVCaches single_kv_cache(&kv_caches[qi], 1); - QueriesPos single_query_pos(&queries_pos[qi], 1); - - // For each batch of tokens in the query: - for (size_t tbatch_start = 0; tbatch_start < prefill_per_query; - tbatch_start += max_tbatch_size) { - // Fill activations.x (much faster than TransformerLayer). - const size_t tbatch_size = - HWY_MIN(max_tbatch_size, prefill_per_query - tbatch_start); - for (size_t ti = 0; ti < tbatch_size; ++ti) { - const int token = queries_prompt[qi][tbatch_start + ti]; - const size_t pos = queries_pos[qi] + ti; - EmbedToken(token, ti, pos, weights, activations.x); - } - - // Transformer with one batch of tokens from a single query. 
- for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer(single_query_pos, tbatch_size, layer, - layer_weights, activations, div_seq_len, - single_kv_cache, inner_pool); - } - - // NOTE: we unconditionally call StreamToken, even if EOS. - for (size_t ti = 0; ti < tbatch_size; ++ti) { - const size_t pos = queries_pos[qi] + ti; - const int token = queries_prompt[qi][pos]; - runtime_config.StreamToken(query_idx_start + qi, pos, token, - 0.0f); - } - - queries_pos[qi] += tbatch_size; - } // for tbatch_start - }); - } - - private: - std::vector activations_; // One per query, filled by Init. -}; +} // Generates one token for each query. `queries_token` is the previous token // from each query, and `queries_pos` are their position in the sequence. @@ -752,7 +704,7 @@ HWY_NOINLINE void Transformer( const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, const CompressedWeights& weights, Activations& activations, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, - hwy::ThreadPool& pool, const LayersOutputFunc& layers_output, + const LayersOutputFunc& layers_output, const ActivationsObserverFunc& activations_observer) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_queries = queries_token.size(); @@ -775,7 +727,7 @@ HWY_NOINLINE void Transformer( const CompressedLayer* layer_weights = weights.GetLayer(layer); TransformerLayer(queries_pos, /*num_tokens=*/1, layer, layer_weights, activations, div_seq_len, - kv_caches, pool); + kv_caches); if (activations_observer) { activations_observer(queries_pos, layer, activations); @@ -880,16 +832,12 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, const RuntimeConfig& runtime_config, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in, const size_t query_idx_start, - const KVCaches& kv_caches, PerClusterPools& pools, - TimingInfo& timing_info) { + const KVCaches& kv_caches, TimingInfo& timing_info) { constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kVocabSize = TConfig::kVocabSize; const CompressedWeights& weights = *reinterpret_cast*>(weights_u8.get()); - // TODO: remove once all parallel sections support hierarchical parallelism. - hwy::ThreadPool& pool = pools.Inner(0); - // Copy so we can increment without requiring users to pass in a mutable span. std::vector queries_pos_copy(queries_pos_in.cbegin(), queries_pos_in.cend()); @@ -930,19 +878,22 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, // Prefill stops before min_prompt_size - 1 because the last prompt token is // the first input token for generation. const size_t prefill_per_query = min_prompt_size - 1; - double prefill_start; - { - // TODO: move to Gemma, reuse across calls to Generate. - PrefillState prefill; - prefill.Init(num_queries, runtime_config.prefill_tbatch_size, - pools); - prefill_start = hwy::platform::Now(); - prefill.Prefill(queries_prompt, prefill_per_query, - queries_mutable_pos, query_idx_start, weights, - runtime_config, div_seq_len, kv_caches, pools); - timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); - // queries_pos are incremented by Prefill. + const double prefill_start = hwy::platform::Now(); + // If tbatch is larger than the qbatch we already have in `activations`, then + // allocate prefill_activations, otherwise reuse. 
+ const bool use_prefill_activations = + runtime_config.prefill_tbatch_size > activations.x.BatchSize(); + Activations prefill_activations; + if (use_prefill_activations) { + prefill_activations.Allocate(runtime_config.prefill_tbatch_size, + activations.env.Pools()); } + Prefill(queries_prompt, prefill_per_query, queries_mutable_pos, + query_idx_start, weights, + use_prefill_activations ? prefill_activations : activations, + runtime_config, div_seq_len, kv_caches); + timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); + // queries_pos are incremented by Prefill. // Storage for the last generated token from each query, passed to the next // Transformer() call. @@ -962,18 +913,18 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, // Decode generates one token per query and increments queries_mutable_pos. Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, weights, activations, div_seq_len, - kv_caches, pool, runtime_config.layers_output, + kv_caches, runtime_config.layers_output, runtime_config.activations_observer); // queries_pos are incremented by Transformer. bool all_queries_eos = true; PROFILER_ZONE("Gen.Embedding"); // Compute logits from last layer activations. - MatMul_4x4( - num_queries, MakeMat(activations.x.All(), kModelDim), - MakeMat(weights.embedder_input_embedding.data(), kModelDim), + MatMul( + num_queries, ConstMat(activations.x.All(), kModelDim), + ConstMat(weights.embedder_input_embedding.data(), kModelDim), weights.embedder_input_embedding.scale(), /*add=*/nullptr, - MakeMat(activations.logits.All(), kVocabSize), pool); + activations.env, MutableMat(activations.logits.All(), kVocabSize)); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); @@ -1001,15 +952,16 @@ void GenerateSingleT(const ByteStorageT& weights_u8, constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; + // TODO: move into Gemma? Activations activations; - activations.Allocate(kNumQueries); + activations.Allocate(kNumQueries, pools); const QueriesPromptTokens prompt_span(&prompt, kNumQueries); QueriesPos pos_span(&pos, kNumQueries); const KVCaches kv_caches{&kv_cache, kNumQueries}; GenerateT(weights_u8, activations, runtime_config, prompt_span, - pos_span, qbatch_start, kv_caches, pools, timing_info); + pos_span, qbatch_start, kv_caches, timing_info); } template @@ -1026,7 +978,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8, (TConfig::kGriffinLayers > 0) ? 
1 : runtime_config.decode_qbatch_size; Activations activations; - activations.Allocate(max_qbatch_size); + activations.Allocate(max_qbatch_size, pools); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { @@ -1038,7 +990,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8, QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); GenerateT(weights_u8, activations, runtime_config, qbatch_prompts, - qbatch_pos, qbatch_start, qbatch_kv, pools, timing_info); + qbatch_pos, qbatch_start, qbatch_kv, timing_info); } } diff --git a/gemma/weights.h b/gemma/weights.h index d8d2ebb..09cbbcb 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -21,6 +21,7 @@ #include "compression/compress.h" #include "gemma/common.h" #include "gemma/configs.h" +#include "util/allocator.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 5d1d435..d14dfd7 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -13,19 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ -#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ - #include -#include -#include -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/profiler.h" // temporarily disabled - -#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ +#include "compression/compress.h" // IWYU pragma: keep, b/conditionally used +#include "ops/matmul.h" // IWYU pragma: export // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE) @@ -35,6 +26,8 @@ #define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE #endif +#include "hwy/highway.h" +// After highway.h #include "compression/compress-inl.h" #include "hwy/contrib/math/math-inl.h" @@ -43,355 +36,392 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -// A square kernel minimizes the ratio of loads to FMA. 4x 128-bit corresponds -// to one cache line. -constexpr size_t kRegRows = 4; +// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of +// loads, we reuse the same A row for several B columns, which are also loaded +// once for several rows of C. Thus we produce one 'tile' of C at a time of +// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are +// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4 +// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume +// and verify that `C.cols % kRegCols == 0`. constexpr size_t kRegCols = 4; -// Initializes a reg-tile of C: if kAdd, `add[add_ofs + c]`; otherwise 0. -// `add` has no scale, and if `kAdd` is a row vector with A.cols entries, +// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because +// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile. +// In general, `batch_size` (C rows) is not a multiple of `kRegRows`. Thus +// functions that load or store a tile are parameterized on `kNumRows`, which is +// generally `kRegRows`, but `batch_size % kRegRows` on the last row (if != 0). 
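To make the `kNumRows` parameterization described above concrete, here is a minimal, self-contained sketch (illustration only, not code from this patch) of how a row loop can hand the `batch_size % kRegRows` remainder to a smaller-`kNumRows` instantiation. `TileKernel` and `ForEachRowTile` are hypothetical stand-ins for `MatMulTile` and the outer loop in `MatMul`:

```
#include <cstddef>
#include <cstdio>

constexpr size_t kRegRows = 4;

// Hypothetical stand-in for MatMulTile<kNumRows, ...>: handles kNumRows rows
// of C starting at row_c.
template <size_t kNumRows>
void TileKernel(size_t row_c) {
  std::printf("tile at row %zu covers %zu row(s)\n", row_c, kNumRows);
}

// Dispatches full kRegRows-row tiles, then the batch_size % kRegRows remainder.
void ForEachRowTile(size_t batch_size) {
  size_t row_c = 0;
  for (; row_c + kRegRows <= batch_size; row_c += kRegRows) {
    TileKernel<kRegRows>(row_c);
  }
  switch (batch_size - row_c) {  // 0..3 rows left over
    case 1: TileKernel<1>(row_c); break;
    case 2: TileKernel<2>(row_c); break;
    case 3: TileKernel<3>(row_c); break;
    default: break;
  }
}

int main() { ForEachRowTile(10); }  // two full tiles, then a 2-row remainder
```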
+constexpr size_t kRegRows = kRegCols; + +// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are +// more efficient than f32 * f32 + f32 because they process twice as many lanes +// at a time. Any combination of A and B can be bf16: activations may already be +// bf16, and weights can be decompressed to bf16. +// +// The corresponding op is `ReordenWidenMulAccumulate`, and it is always +// supported, but only useful if it returns a single vector of pairwise sums +// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReordenWidenMulAccumulate` +// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep +// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be +// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B +// to bf16 if the native op is available. This will actually demote f32 +// activations to bf16. Otherwise, we decompress to f32 and use normal FMA. +using MulT = hwy::If; + +// Loads two vectors at a time with element type MulT from a row of transposed +// B. Called in a loop over col_ab. No bounds checking because `kRow` is +// actually from B columns, which we checked is a multiple of `kRegCols`. +template +class BRow { + static_assert(kRow < kRegRows); // which unrolled instance we are + using TraitsB = CompressTraits; + + public: + BRow(const Mat& B, size_t row_b) + : B_(B.ptr), B_ofs_(B.Row(row_b + kRow)) {} + + template > + HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const { + static_assert(hwy::IsSame, MulT>()); + TraitsB::Decompress2(d, B_, B_ofs_ + col_ab, b0, b1); + } + + private: + const MatTB* HWY_RESTRICT B_; + const size_t B_ofs_; +}; + +// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise +// with `kRegRows` x 2 row vectors from transposed B, and adds them to +// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset of +// the terms of the dot products that make up the MatMul result at `r,c`. +// No-op for the bottom-most tile where kRow >= kNumRows. +// +// This approach is atypical because it requires a horizontal sum, for which we +// introduce a fast and new(?) vector-length agnostic 'transpose', see +// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and +// multiply with one element from N columns in B to obtain N columns of C. +// This is a poor fit for our setting: +// - `CompressTraits` decompresses two vectors at a time; +// - B is column-major, so unit-stride SIMD loads return a column, not values +// from different columns, i.e. a row. +// Both could be fixed in a packing stage, which is not implemented yet, and +// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is +// important for bf16 performance and incompatible with the conventional +// approach, because its pairwise adds would add together unrelated terms. +// By contrast, pairwise adds are fine when our C lanes are the terms of a +// single dot product, which can be reordered or pre-reduced. +template +class ALoadAccumulate { + static_assert(kRow < kRegRows); // which unrolled instance we are + using TraitsA = CompressTraits; + + public: + ALoadAccumulate(const Mat& A, size_t row_ac) + : A_(A.ptr), A_ofs_(A.Row(row_ac + kRow)) {} + + // First iteration, col_ab = 0: initialize C0..3 instead of updating them. 
+ template , HWY_IF_F32_D(DM)> + HWY_INLINE void First(DM dm, // + const VM b00, const VM b01, const VM b10, const VM b11, + const VM b20, const VM b21, const VM b30, const VM b31, + VM& C0, VM& C1, VM& C2, VM& C3) const { + static_assert(kNumRows <= kRegRows); // How many rows actually present + if constexpr (kRow < kNumRows) { + VM a0, a1; + TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1); + + static_assert(kRegCols == 4); + C0 = hn::Mul(a0, b00); + C1 = hn::Mul(a0, b10); + C2 = hn::Mul(a0, b20); + C3 = hn::Mul(a0, b30); + C0 = hn::MulAdd(a1, b01, C0); + C1 = hn::MulAdd(a1, b11, C1); + C2 = hn::MulAdd(a1, b21, C2); + C3 = hn::MulAdd(a1, b31, C3); + } + } + + // Same as above, only called if MulT == BF16. + template , + HWY_IF_BF16_D(DM), class DF = hn::Repartition, + class VF = hn::Vec> + HWY_INLINE void First(DM dm, // + const VM b00, const VM b01, const VM b10, const VM b11, + const VM b20, const VM b21, const VM b30, const VM b31, + VF& C0, VF& C1, VF& C2, VF& C3) const { + static_assert(kNumRows <= kRegRows); // How many rows actually present + if constexpr (kRow < kNumRows) { + VM a0, a1; + TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1); + + const DF df; + VF unused_sum1 = hn::Zero(df); + + static_assert(kRegCols == 4); + C0 = hn::WidenMulPairwiseAdd(df, a0, b00); + C1 = hn::WidenMulPairwiseAdd(df, a0, b10); + C2 = hn::WidenMulPairwiseAdd(df, a0, b20); + C3 = hn::WidenMulPairwiseAdd(df, a0, b30); + C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); + + // Ensure sum1 was indeed unused. + HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + } + } + + // Non-first iteration: accumulate into C0..3. + template , HWY_IF_F32_D(DM)> + HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01, + const VM b10, const VM b11, const VM b20, const VM b21, + const VM b30, const VM b31, VM& C0, VM& C1, VM& C2, + VM& C3) const { + static_assert(kNumRows <= kRegRows); // How many rows actually present + HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration. + if constexpr (kRow < kNumRows) { + VM a0, a1; + TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); + + static_assert(kRegCols == 4); + C0 = hn::MulAdd(a0, b00, C0); + C1 = hn::MulAdd(a0, b10, C1); + C2 = hn::MulAdd(a0, b20, C2); + C3 = hn::MulAdd(a0, b30, C3); + C0 = hn::MulAdd(a1, b01, C0); + C1 = hn::MulAdd(a1, b11, C1); + C2 = hn::MulAdd(a1, b21, C2); + C3 = hn::MulAdd(a1, b31, C3); + } + } + + // Same as above, only called if MulT == BF16. + template , + HWY_IF_BF16_D(DM), class DF = hn::Repartition, + class VF = hn::Vec> + HWY_INLINE void Next(DM dm, size_t col_ab, const VM b00, const VM b01, + const VM b10, const VM b11, const VM b20, const VM b21, + const VM b30, const VM b31, VF& C0, VF& C1, VF& C2, + VF& C3) const { + static_assert(kNumRows <= kRegRows); // How many rows actually present + HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm)); // Should not be first iteration. 
+ if constexpr (kRow < kNumRows) { + VM a0, a1; + TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1); + + const DF df; + hn::Vec unused_sum1 = hn::Zero(df); + + static_assert(kRegCols == 4); + C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1); + C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1); + C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1); + C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1); + C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1); + + // Ensure sum1 was indeed unused. + HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); + } + } + + private: + const MatTA* HWY_RESTRICT A_; + const size_t A_ofs_; +}; // ALoadAccumulate + +// Sets a `kRegRows` x `kRegCols` tile of C to `add[add_ofs + c]` if kAdd, +// otherwise 0. +// `add` has no scale and is a row vector with A.cols entries if `kAdd`, // otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB, // hence we pass it as a separate argument. template HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs, float* HWY_RESTRICT pos_c, size_t stride_c) { + const hn::FixedTag d4; for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) { - for (size_t c = 0; c < kRegCols; ++c) { - if constexpr (kAdd) { - pos_c[r * stride_c + c] = add[add_ofs + c]; - } else { - pos_c[r * stride_c + c] = 0.0f; - } + if constexpr (kAdd) { + hn::StoreU(hn::LoadU(d4, add + add_ofs), d4, pos_c + r * stride_c); + } else { + hn::StoreU(hn::Zero(d4), d4, pos_c + r * stride_c); } } } -// c## are partial sums of the products of A and B; their horizontal sums are -// the final matmul result, stored in C, which is always f32. -template > -HWY_INLINE void AddHorizontalSums(DF df, float scale, // - VF c00, VF c01, VF c02, VF c03, // - VF c10, VF c11, VF c12, VF c13, // - VF c20, VF c21, VF c22, VF c23, // - VF c30, VF c31, VF c32, VF c33, // - float* HWY_RESTRICT tile_c, size_t stride_c) { - // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. - // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in - // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is - // expensive, but only a fraction of the A.cols/N FMAs. - // TODO: 4x4 transpose, then 128-bit vector FMA? 
- tile_c[stride_c * 0 + 0] += scale * hn::ReduceSum(df, c00); - tile_c[stride_c * 0 + 1] += scale * hn::ReduceSum(df, c01); - tile_c[stride_c * 0 + 2] += scale * hn::ReduceSum(df, c02); - tile_c[stride_c * 0 + 3] += scale * hn::ReduceSum(df, c03); - if (kNumRows == 1) return; - - tile_c[stride_c * 1 + 0] += scale * hn::ReduceSum(df, c10); - tile_c[stride_c * 1 + 1] += scale * hn::ReduceSum(df, c11); - tile_c[stride_c * 1 + 2] += scale * hn::ReduceSum(df, c12); - tile_c[stride_c * 1 + 3] += scale * hn::ReduceSum(df, c13); - if (kNumRows == 2) return; - - tile_c[stride_c * 2 + 0] += scale * hn::ReduceSum(df, c20); - tile_c[stride_c * 2 + 1] += scale * hn::ReduceSum(df, c21); - tile_c[stride_c * 2 + 2] += scale * hn::ReduceSum(df, c22); - tile_c[stride_c * 2 + 3] += scale * hn::ReduceSum(df, c23); - if (kNumRows == 3) return; - - tile_c[stride_c * 3 + 0] += scale * hn::ReduceSum(df, c30); - tile_c[stride_c * 3 + 1] += scale * hn::ReduceSum(df, c31); - tile_c[stride_c * 3 + 2] += scale * hn::ReduceSum(df, c32); - tile_c[stride_c * 3 + 3] += scale * hn::ReduceSum(df, c33); -} - -// Wrapper to simplify call sites. T can be const or non-const. -template -struct Mat { - bool NotEmpty() const { - return ptr != nullptr && cols != 0 && stride >= cols; +// Accumulates into a tile of C. +template +class AddHorizontalSums { + // These helper functions hoist if() out of the main code below. They have no + // effect if kRow >= kNumRows. + template > + static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2, + VF Cr3, float* HWY_RESTRICT buf) { + if constexpr (kRow < kNumRows) { + hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N); + } } - size_t Row(size_t r) const { return ofs + stride * r; } - T* HWY_RESTRICT ptr; - size_t cols; + // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. + template > + static V4 MaybeLoad(D4 df, size_t N, const float* HWY_RESTRICT buf) { + if constexpr (kRow < kNumRows) { + return hn::Load(df, buf + 4 * kRow * N); + } else { + return hn::Zero(df); + } + } - // elements between rows, which is typically the same as `cols`. - size_t stride; + template > + static V4 MaybeAdd(D4 df, size_t N, V4 sum, const float* HWY_RESTRICT buf) { + if constexpr (kRow < kNumRows) { + return hn::Add(sum, hn::Load(df, buf + 4 * kRow * N)); + } else { + return sum; + } + } - // Offset to add to `ptr`; separate because T=NuqStream does not support - // pointer arithmetic. - size_t ofs; + template > + static void MaybeMulAdd(D4 df, V4 sum, V4 scale, float* HWY_RESTRICT tile_c, + const size_t stride_c) { + if constexpr (kRow < kNumRows) { + const V4 prev_c = hn::LoadU(df, tile_c + kRow * stride_c); + hn::StoreU(hn::MulAdd(sum, scale, prev_c), df, tile_c + kRow * stride_c); + } + } + + public: + // Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose + // top left is `tile_c`, after multiplying by `scale`, which is the product of + // the scales of A and B. C is always f32 to ensure sufficient precision. + // + // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a + // B column vector indexed by `c`. Their elements are thus a subset of the + // terms of the dot product constituting the final `C[r, c]` result. Thus we + // compute the horizontal sums of each `Crc`. The elements may be permuted + // because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does + // not change their horizontal sum. `buf` is thread-local space for 16 `VF`. 
+ template > + HWY_INLINE void operator()(DF df, float scale, // + VF C00, VF C01, VF C02, VF C03, // + VF C10, VF C11, VF C12, VF C13, // + VF C20, VF C21, VF C22, VF C23, // + VF C30, VF C31, VF C32, VF C33, // + float* HWY_RESTRICT buf, + float* HWY_RESTRICT tile_c, + size_t stride_c) const { + const size_t N = hn::Lanes(df); + // Horizontal reductions (`ReduceSum`) are rather expensive, entailing + // log(N) operations for vectors of length N. Because kRegCols == 4, we can + // instead use `StoreInterleaved4` for a vector length-agnostic 'transpose': + // `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0], + // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1], C03[N-1]. + MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf); + MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf); + MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf); + MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf); + // Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2, Cr3 + // in the elements of one V4. We have four independent rows `r`, hence the + // code is effectively unrolled, which increases throughput. + const hn::FixedTag d4; + using V4 = hn::Vec; + V4 sum0 = MaybeLoad<0>(d4, N, buf); + V4 sum1 = MaybeLoad<1>(d4, N, buf); + V4 sum2 = MaybeLoad<2>(d4, N, buf); + V4 sum3 = MaybeLoad<3>(d4, N, buf); + + for (size_t i = 1; i < N; ++i) { + sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i); + sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i); + sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i); + sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i); + } + // Scale, then store to four elements per row of `tile_c`. + const V4 vscale = hn::Set(d4, scale); + MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c); + MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c); + MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c); + MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c); + } }; -template -Mat MakeMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride, - size_t ofs = 0) { - return Mat{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs}; -} - -template -Mat MakeMat(T* HWY_RESTRICT ptr, size_t cols) { - return MakeMat(ptr, cols, cols); -} - -// Inner loop of the kernel, called once per kRegRows. c[r] += a[c] * b[r,c]. -// The col_ab loop is unrolled 2x, so we have a0/a1 and b00/b01 etc. -template -HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00, - const VF& b01, const VF& b10, const VF& b11, - const VF& b20, const VF& b21, const VF& b30, - const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) { - c0 = hn::MulAdd(a0, b00, c0); - c1 = hn::MulAdd(a0, b10, c1); - c2 = hn::MulAdd(a0, b20, c2); - c3 = hn::MulAdd(a0, b30, c3); - c0 = hn::MulAdd(a1, b01, c0); - c1 = hn::MulAdd(a1, b11, c1); - c2 = hn::MulAdd(a1, b21, c2); - c3 = hn::MulAdd(a1, b31, c3); -} - -// Special case for the first iteration: c## are zero, so skip the first add. 
-template -HWY_INLINE void FirstTileRow(const VF& a0, const VF& a1, const VF& b00, - const VF& b01, const VF& b10, const VF& b11, - const VF& b20, const VF& b21, const VF& b30, - const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) { - c0 = hn::Mul(a0, b00); - c1 = hn::Mul(a0, b10); - c2 = hn::Mul(a0, b20); - c3 = hn::Mul(a0, b30); - c0 = hn::MulAdd(a1, b01, c0); - c1 = hn::MulAdd(a1, b11, c1); - c2 = hn::MulAdd(a1, b21, c2); - c3 = hn::MulAdd(a1, b31, c3); -} - -#undef GEMMA_NATIVE_BF16 -#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \ - defined(HWY_TARGET_TOGGLE)) -#define GEMMA_NATIVE_BF16 1 -#else -#define GEMMA_NATIVE_BF16 0 -#endif - -#if GEMMA_NATIVE_BF16 - -// Specializations for f32 += bf16 * bf16 that avoid promoting to f32. - -// Inner loop as above, but not unrolled. c[r] += a * b[r]. -template , - class VBF16 = hn::Vec>> -HWY_INLINE void UpdateTileRow(DF df, const VBF16& a, const VBF16& b0, - const VBF16& b1, const VBF16& b2, const VBF16& b3, - VF& c0, VF& c1, VF& c2, VF& c3) { - DF df; - VF unused_sum1 = hn::Zero(df); - c0 = hn::ReorderWidenMulAccumulate(df, a, b0, c0, unused_sum1); - c1 = hn::ReorderWidenMulAccumulate(df, a, b1, c1, unused_sum1); - c2 = hn::ReorderWidenMulAccumulate(df, a, b2, c2, unused_sum1); - c3 = hn::ReorderWidenMulAccumulate(df, a, b3, c3, unused_sum1); - - // Ensure sum1 was indeed unused. - HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); -} - -// Special case for the first iteration: c## are zero, so skip the first add. -template , - class VBF16 = hn::Vec>> -HWY_INLINE void FirstTileRow(DF df, const VBF16& a, const VBF16& b0, - const VBF16& b1, const VBF16& b2, const VBF16& b3, - VF& c0, VF& c1, VF& c2, VF& c3) { - c0 = hn::WidenMulPairwiseAdd(df, a, b0); - c1 = hn::WidenMulPairwiseAdd(df, a, b1); - c2 = hn::WidenMulPairwiseAdd(df, a, b2); - c3 = hn::WidenMulPairwiseAdd(df, a, b3); -} - -template -HWY_INLINE void MatMulTile(const Mat& A, const Mat& B, - const size_t row_ac, const size_t row_b_col_c, - const float scale, const float* HWY_RESTRICT add, - const Mat& C) { - const hn::ScalableTag df; - using VF = hn::Vec; - // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full - // bf16 vectors. - const hn::Repartition d; - const size_t N = Lanes(d); - using V = hn::Vec; - V b0, b1, b2, b3; // one from each row - VF c00, c01, c02, c03; - VF c10, c11, c12, c13; - VF c20, c21, c22, c23; - VF c30, c31, c32, c33; - - const BF16* HWY_RESTRICT A_tile = A.ptr + A.Row(row_ac); - const BF16* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c); - float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c; - InitC(add, row_b_col_c, C_tile, C.stride); - - size_t col_ab = 0; - - // First iteration initializes the c## vectors. 
- { - b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab); - b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab); - b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab); - b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab); - - { - const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab); - FirstTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03); - } - if constexpr (kNumRows > 1) { - const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab); - FirstTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13); - } - if constexpr (kNumRows > 2) { - const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab); - FirstTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23); - } - if constexpr (kNumRows == 3) { - const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab); - FirstTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33); - } - } - - // Loop over columns of A and columns of the transposed B, in steps of N. - // Accumulates into the c## vectors. - HWY_UNROLL(1) - for (col_ab += N; col_ab < A.cols; col_ab += N) { - b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab); - b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab); - b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab); - b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab); - - { - const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab); - UpdateTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03); - } - if constexpr (kNumRows > 1) { - const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab); - UpdateTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13); - } - if constexpr (kNumRows > 2) { - const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab); - UpdateTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23); - } - if constexpr (kNumRows == 3) { - const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab); - UpdateTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33); - } - } - - AddHorizontalSums(df, scale, c00, c01, c02, c03, c10, c11, c12, c13, - c20, c21, c22, c23, c30, c31, c32, c33, C_tile, - C.stride); -} - -#endif // GEMMA_NATIVE_BF16 - -// Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a -// finished tile of `C`. -// General case: uses CompressTraits to load from A and B. +// Streams a `kNumRows` high strip of `A` and the transposed `B`, then writes a +// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c). +// TODO: loop over sections instead of full rows and accumulate into `tile_c`. template -HWY_INLINE void MatMulTile(const Mat& A, const Mat& B, +HWY_INLINE void MatMulTile(const Mat& A, const Mat& B, const size_t row_ac, const size_t row_b_col_c, const float scale, const float* HWY_RESTRICT add, - const Mat& C) { - using TraitsA = CompressTraits>; - using TraitsB = CompressTraits>; + float* HWY_RESTRICT buf, const Mat& C) { + // For 'decompressing' A and B into BF16 or float. 
+ const hn::ScalableTag dm; + using VM = hn::Vec; + const size_t NM = hn::Lanes(dm); - const hn::ScalableTag d32; - const size_t N = hn::Lanes(d32); - using V = hn::Vec; - V b00, b01, b10, b11, b20, b21, b30, b31; // two from each row - V c00, c01, c02, c03; - V c10, c11, c12, c13; - V c20, c21, c22, c23; - V c30, c31, c32, c33; + static_assert(kRegRows == 4); + const BRow<0, MatTB> b_row0(B, row_b_col_c); + const BRow<1, MatTB> b_row1(B, row_b_col_c); + const BRow<2, MatTB> b_row2(B, row_b_col_c); + const BRow<3, MatTB> b_row3(B, row_b_col_c); - const size_t A_ofs = A.Row(row_ac); - const size_t B_ofs = B.Row(row_b_col_c); + const ALoadAccumulate<0, MatTA> a_row0(A, row_ac); + const ALoadAccumulate<1, MatTA> a_row1(A, row_ac); + const ALoadAccumulate<2, MatTA> a_row2(A, row_ac); + const ALoadAccumulate<3, MatTA> a_row3(A, row_ac); + + const hn::Repartition df; + using VF = hn::Vec; + VF C00, C01, C02, C03; + VF C10, C11, C12, C13; + VF C20, C21, C22, C23; + VF C30, C31, C32, C33; + + { // First iteration initializes the `Crc` vectors. + VM b00, b01, b10, b11, b20, b21, b30, b31; + b_row0.Load2(dm, /*col_ab=*/0, b00, b01); + b_row1.Load2(dm, /*col_ab=*/0, b10, b11); + b_row2.Load2(dm, /*col_ab=*/0, b20, b21); + b_row3.Load2(dm, /*col_ab=*/0, b30, b31); + + a_row0.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + C00, C01, C02, C03); + a_row1.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + C10, C11, C12, C13); + a_row2.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + C20, C21, C22, C23); + a_row3.template First(dm, b00, b01, b10, b11, b20, b21, b30, b31, + C30, C31, C32, C33); + } + + // `2 * NM` per iteration because `Load2` returns two vectors. + HWY_UNROLL(1) + for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) { + VM b00, b01, b10, b11, b20, b21, b30, b31; + b_row0.Load2(dm, col_ab, b00, b01); + b_row1.Load2(dm, col_ab, b10, b11); + b_row2.Load2(dm, col_ab, b20, b21); + b_row3.Load2(dm, col_ab, b30, b31); + + a_row0.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + b30, b31, C00, C01, C02, C03); + a_row1.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + b30, b31, C10, C11, C12, C13); + a_row2.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + b30, b31, C20, C21, C22, C23); + a_row3.template Next(dm, col_ab, b00, b01, b10, b11, b20, b21, + b30, b31, C30, C31, C32, C33); + } + + // TODO: hoist into outer loop. float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c; InitC(add, row_b_col_c, C_tile, C.stride); - // Loop over columns of A and columns of the transposed B, in steps of 2*N - // (since we are decoding consecutive bytes at each iteration). - // Top-left of tile is (row_ac, col_ab) for A, and (row_b_col_c, - // col_ab) for B. First iteration initializes the c## vectors. 
- size_t col_ab = 0; - - { - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31); - - { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1); - FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01, - c02, c03); - } - if constexpr (kNumRows > 1) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1); - FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11, - c12, c13); - } - if constexpr (kNumRows > 2) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1); - FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21, - c22, c23); - } - if constexpr (kNumRows > 3) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1); - FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31, - c32, c33); - } - } - - // Main loop: accumulates into the c## vectors. - HWY_UNROLL(1) - for (col_ab += 2 * N; col_ab <= A.cols - 2 * N; col_ab += 2 * N) { - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21); - TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31); - - { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1); - UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01, - c02, c03); - } - if constexpr (kNumRows > 1) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1); - UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11, - c12, c13); - } - if constexpr (kNumRows > 2) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1); - UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21, - c22, c23); - } - if constexpr (kNumRows > 3) { - V a0, a1; - TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1); - UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31, - c32, c33); - } - } - - AddHorizontalSums(d32, scale, c00, c01, c02, c03, c10, c11, c12, - c13, c20, c21, c22, c23, c30, c31, c32, c33, - C_tile, C.stride); + AddHorizontalSums()(df, scale, C00, C01, C02, C03, C10, C11, C12, + C13, C20, C21, C22, C23, C30, C31, C32, C33, + buf, C_tile, C.stride); } // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`. @@ -402,28 +432,28 @@ HWY_INLINE void MatMulTile(const Mat& A, const Mat& B, // // `scale` allows expanding the smaller range of `SfpStream` to the original // values. When `A` and/or `B` are from CompressedArray, `scale` should be the -// product of their `.scale()` values. +// product of their `.scale()` values, otherwise 1.0f. // // If `kAdd` is true, the row-vector `add` is added to each row of `C`, // otherwise `add` is ignored and can be nullptr. A scale for `add` is not // supported, so make sure its scale is 1. // // `C` is a row-major matrix of size `(batch_size, C.cols)`. -// Writes 4x4 tiles of C in parallel using a work-stealing thread pool. -// Typically batch_size is 1..512, A.cols and C.cols are 3k or 24k. 
+// +// Updates 4x4 tiles of C in parallel using a work-stealing thread pool. +// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k. +// Must not be called concurrently with the same `env`. template -HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat& A, - const Mat& B, const float scale, - const float* HWY_RESTRICT add, const Mat& C, - hwy::ThreadPool& pool) { +HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A, + const Mat& B, const float scale, + const float* HWY_RESTRICT add, MatMulEnv& env, + const Mat& C) { // PROFILER_ZONE("Matmul"); HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty()); HWY_DASSERT(A.cols == B.cols); - // Use float instead of MatTA/MatTB because we decompress to float here. - const size_t N = hn::Lanes(hn::ScalableTag()); - (void)N; - HWY_DASSERT(A.cols % (N * 2) == 0); // For Decompress2. + // Must be a multiple of two vectors because we Decompress2. + HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0); HWY_DASSERT(C.cols % kRegCols == 0); // We currently write C directly, which touches more memory than fits in L3. @@ -431,30 +461,32 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat& A, const size_t tilesY = hwy::DivCeil(batch_size, kRegRows); const size_t tilesX = C.cols / kRegCols; - pool.Run(0, tilesX * tilesY, - [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR { - const size_t tx = idx_tile % tilesX; - const size_t ty = idx_tile / tilesX; - const size_t row_ac = ty * kRegRows; - const size_t row_b_col_c = tx * kRegCols; - // How many rows of C are left to compute. If more than 4, this - // tile still only computes 4 rows. - const size_t num_rows = batch_size - row_ac; - HWY_DASSERT(num_rows != 0); - switch (num_rows) { - case 1: - MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C); - break; - case 2: - MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C); - break; - case 3: - MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C); - break; - default: - MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C); - } - }); + env.Pool().Run( + 0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR { + // TODO: when using PerClusterPool, compute lp from outer and inner. + float* HWY_RESTRICT buf = env.Buf(thread); + const size_t tx = idx_tile % tilesX; + const size_t ty = idx_tile / tilesX; + const size_t row_ac = ty * kRegRows; + const size_t row_b_col_c = tx * kRegCols; + // How many rows of C are left to compute. If more than 4, this + // tile still only computes 4 rows. + const size_t num_rows = batch_size - row_ac; + HWY_DASSERT(num_rows != 0); + switch (num_rows) { + case 1: + MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C); + break; + case 2: + MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C); + break; + case 3: + MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C); + break; + default: + MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C); + } + }); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matmul.h b/ops/matmul.h new file mode 100644 index 0000000..ecc72b1 --- /dev/null +++ b/ops/matmul.h @@ -0,0 +1,97 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_
+
+#include <stddef.h>
+
+#include "util/allocator.h"  // RowVectorBatch
+#include "util/threading.h"  // PerClusterPools
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/per_target.h"
+
+namespace gcpp {
+
+// Bundles ptr/size/stride arguments to simplify MatMul call sites. T can be
+// const or non-const. Create via ConstMat/MutableMat.
+template <typename T>
+struct Mat {
+  bool NotEmpty() const {
+    return ptr != nullptr && cols != 0 && stride >= cols;
+  }
+  size_t Row(size_t r) const { return ofs + stride * r; }
+
+  T* HWY_RESTRICT ptr;
+  size_t cols;
+
+  // elements between rows, which is typically the same as `cols`.
+  size_t stride;
+
+  // Offset to add to `ptr`; separate because T=NuqStream does not support
+  // pointer arithmetic.
+  size_t ofs;
+};
+
+template <typename T>
+Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols, size_t stride,
+                  size_t ofs = 0) {
+  return Mat<T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
+}
+
+template <typename T>
+Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols, size_t stride,
+                      size_t ofs = 0) {
+  return Mat<const T>{.ptr = ptr, .cols = cols, .stride = stride, .ofs = ofs};
+}
+
+template <typename T>
+Mat<const T> ConstMat(Mat<T> mat) {
+  return ConstMat(mat.ptr, mat.cols, mat.stride, mat.ofs);
+}
+
+template <typename T>
+Mat<T> MutableMat(T* HWY_RESTRICT ptr, size_t cols) {
+  return MutableMat(ptr, cols, cols);
+}
+
+template <typename T>
+Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
+  return ConstMat(ptr, cols, cols);
+}
+
+// Allocations and threads, shared across MatMul calls.
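+//
+// Illustrative call sequence (a sketch; exact template arguments and pool
+// sizes are up to the caller, see matmul_test.cc):
+//   PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1);
+//   MatMulEnv env(pools);
+//   // A is (batch_size x cols), row-major; the second arg is row-major B^T.
+//   MatMul</*kAdd=*/false>(batch_size, ConstMat(a, cols),
+//                          ConstMat(b_trans, cols), /*scale=*/1.0f,
+//                          /*add=*/nullptr, env, MutableMat(c, c_cols));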
+class MatMulEnv { + public: + MatMulEnv() : pools_(nullptr) {} + explicit MatMulEnv(PerClusterPools& pools) : pools_(&pools) { + const size_t num_lp = pools.NumLP(); + const size_t NF = hwy::VectorBytes() / sizeof(float); + buf_ = RowVectorBatch(num_lp, 16 * NF); + } + + float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); } + PerClusterPools& Pools() const { return *pools_; } + hwy::ThreadPool& Pool() const { return pools_->Inner(0); } + + private: + RowVectorBatch buf_; + PerClusterPools* pools_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index b34321e..15289fa 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -24,6 +24,7 @@ #include #include "compression/compress.h" +#include "util/threading.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -35,9 +36,10 @@ // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" -#include "hwy/tests/test_util-inl.h" // After highway.h +#include "compression/compress-inl.h" #include "ops/matmul-inl.h" +#include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -149,7 +151,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b, const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) * MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab); const double epsilon = hwy::ConvertScalarTo(hwy::Epsilon()); - const double tolerance = 50.0 * norm * epsilon; + const double tolerance = 200.0 * norm * epsilon; for (size_t idx = 0; idx < num_c; idx++) { const double expected_value = expected_c[idx]; @@ -157,8 +159,10 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b, if (!(expected_value - tolerance <= actual_value && actual_value <= expected_value + tolerance)) { - fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx, - expected_value, idx, actual_value); + fprintf( + stderr, + "expected[%lu]: %f, actual[%lu]: %f, norm %f eps %E tolerance %f\n", + idx, expected_value, idx, actual_value, norm, epsilon, tolerance); HWY_ASSERT(0); } } @@ -202,14 +206,15 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, double elapsed) { - // 2 because of FMA. - fprintf(stderr, "%s: %f seconds, %f GFLOPS.\n", algo, elapsed, - 2E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed); + // 2x because of FMA. + fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, + elapsed, 2 * 1E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed); } template -void TestMatMul(hwy::ThreadPool& pool) { +void TestMatMul(MatMulEnv& env) { + hwy::ThreadPool& pool = env.Pool(); using TraitsA = CompressTraits; using TraitsB = CompressTraits; const bool want_bench = kColsBC > 2000; // avoid spam for small matrices @@ -247,14 +252,14 @@ void TestMatMul(hwy::ThreadPool& pool) { double min_elapsed = hwy::HighestValue(); for (int rep = 0; rep < (want_bench ? 3 : 1); ++rep) { const double start_tiled = hwy::platform::Now(); - MatMul_4x4(kRowsAC, MakeMat(a->data(), kColsARowsB), - MakeMat(b_trans->data(), kColsARowsB), scale, - kAdd ? add->data_scale1() : nullptr, - MakeMat(c.get(), kColsBC), pool); + MatMul(kRowsAC, ConstMat(a->data(), kColsARowsB), + ConstMat(b_trans->data(), kColsARowsB), scale, + kAdd ? 
add->data_scale1() : nullptr, env, + MutableMat(c.get(), kColsBC)); min_elapsed = HWY_MIN(min_elapsed, hwy::platform::Now() - start_tiled); } if (want_bench) { - PrintSpeed("MatMul_4x4", kRowsAC, kColsARowsB, kColsBC, min_elapsed); + PrintSpeed("MatMul", kRowsAC, kColsARowsB, kColsBC, min_elapsed); } AssertClose(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), @@ -268,53 +273,56 @@ void TestAllMatMul() { return; } - hwy::ThreadPool pool(4); + PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1); + MatMulEnv env(pools); using F32 = float; using SFP = SfpStream; // large-scale test - TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(pool); - TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(pool); + TestMatMul<64, 24576, 3072, /*kAdd=*/false, BF16, SFP>(env); + TestMatMul<64, 3072, 24576, /*kAdd=*/false, BF16, SFP>(env); + TestMatMul<64, 24576, 3072, /*kAdd=*/false, F32, SFP>(env); + TestMatMul<64, 3072, 24576, /*kAdd=*/false, F32, SFP>(env); // medium-sized square test - TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(pool); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(pool); - TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(pool); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(pool); - TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(pool); - TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(pool); + TestMatMul<512, 512, 512, /*kAdd=*/false, F32>(env); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16>(env); + TestMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(env); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(env); + TestMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(env); + TestMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(env); // minimal non-square test. kColsARowsB must be at least 2 vectors. 
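+  // Row counts that are not multiples of 4 also exercise the 1..3-row tail
+  // handled by the `num_rows` switch in MatMul.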
- TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(pool); - TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(pool); - TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(pool); - TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(pool); - TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(pool); - TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(pool); - TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(pool); - TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(pool); - TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(pool); - TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(pool); + TestMatMul<35, 128, 32, /*kAdd=*/false, F32>(env); + TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env); + TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env); + TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env); + TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env); + TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env); + TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env); + TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env); + TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env); + TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env); + TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env); + TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env); + TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env); + TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/util/allocator.h b/util/allocator.h new file mode 100644 index 0000000..f459fe7 --- /dev/null +++ 
b/util/allocator.h
@@ -0,0 +1,75 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
+#define THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+
+namespace gcpp {
+
+using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
+
+template <typename T>
+ByteStorageT AllocateSizeof() {
+  return hwy::AllocateAligned<uint8_t>(sizeof(T));
+}
+
+// Owns dynamically-allocated aligned memory for a batch of row vectors.
+// This can be seen as a (batch_size x len) matrix.
+template <typename T>
+class RowVectorBatch {
+ public:
+  // Default ctor for Activations ctor.
+  RowVectorBatch() : batch_size_(0), len_(0) {}
+  // Main ctor, called from Activations::Allocate.
+  RowVectorBatch(size_t batch_size, size_t len)
+      : batch_size_(batch_size), len_(len) {
+    mem_ = hwy::AllocateAligned<T>(batch_size * len);
+  }
+
+  // Move-only
+  RowVectorBatch(RowVectorBatch&) noexcept = delete;
+  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
+  RowVectorBatch(RowVectorBatch&&) noexcept = default;
+  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
+
+  size_t BatchSize() const { return batch_size_; }
+  size_t Len() const { return len_; }
+
+  // Returns the given row vector of length `Len()`.
+  T* Batch(size_t batch_idx) {
+    HWY_DASSERT(batch_idx < batch_size_);
+    return mem_.get() + batch_idx * len_;
+  }
+
+  // For MatMul or other operations that process the entire batch at once.
+  T* All() { return mem_.get(); }
+  const T* Const() const { return mem_.get(); }
+  size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); }
+
+ private:
+  hwy::AlignedFreeUniquePtr<T[]> mem_;
+  size_t batch_size_;  // rows in the matrix
+  size_t len_;         // columns in the matrix = vector length
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
diff --git a/util/threading.h b/util/threading.h
index 4a995ff..bc2d63e 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -197,6 +197,11 @@ class PerClusterPools {
     return *inner_pools_[outer];
   }
 
+  // Returns number of logical processors, for allocating per-thread buffers.
+  size_t NumLP() const {
+    return outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers();
+  }
+
  private:
   bool have_threading_support_;
   CoreBitSets cores_per_cluster_;