From d538a6d6c6b93d0bc09b140d4d057f52d949ebf8 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 13 May 2025 01:05:42 -0700 Subject: [PATCH] Cleanup: remove unused kCyclic, remove 2 suffix Also remove now unused allocator arg and fix warnings (cast, struct/class mismatch) PiperOrigin-RevId: 758098495 --- gemma/configs.cc | 2 - gemma/gemma-inl.h | 44 +++++++------------- gemma/gemma_args.h | 2 +- gemma/kv_cache.h | 1 - io/io.cc | 2 +- io/io.h | 2 +- io/io_win.cc | 2 +- ops/bench_matmul.cc | 2 +- ops/dot_test.cc | 2 +- ops/matmul-inl.h | 31 ++++++-------- ops/matmul.h | 10 ++--- ops/matmul_test.cc | 5 +-- ops/ops_test.cc | 2 +- util/allocator.cc | 31 +++++++------- util/allocator.h | 47 ++++++++++----------- util/mat.cc | 41 ++++++++----------- util/mat.h | 99 +++++++++------------------------------------ util/threading.h | 2 +- 18 files changed, 114 insertions(+), 213 deletions(-) diff --git a/gemma/configs.cc b/gemma/configs.cc index 291390c..ed42062 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -42,8 +42,6 @@ static ModelConfig ConfigNoSSM() { return config; } -static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); } - static ModelConfig ConfigBaseGemmaV2() { ModelConfig config = ConfigNoSSM(); config.att_cap = 50.0f; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 29c898b..f658099 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -31,7 +31,6 @@ #include "gemma/kv_cache.h" #include "gemma/weights.h" #include "paligemma/image.h" -#include "util/allocator.h" #include "util/mat.h" #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span @@ -260,8 +259,7 @@ class GemmaAttention { const size_t w1_rows = heads * layer_config_.QStride(); w_q1.ShrinkRows(w1_rows); MatMul(activations_.pre_att_rms_out, w_q1, - /*add=*/nullptr, *activations_.env, - RowPtrFromMat(allocator_, activations_.q)); + /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q)); if (is_mha_) { // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. @@ -284,7 +282,7 @@ class GemmaAttention { const size_t kv_ofs = queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_; float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; - RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols); + RowPtrF kv_rows(kv, w_rows_kv_cols); kv_rows.SetStride(cache_pos_size_); MatMul(activations_.pre_att_rms_out, w_q2, /*add=*/nullptr, *activations_.env, kv_rows); @@ -490,7 +488,7 @@ class GemmaAttention { ? 
layer_weights_.attention_output_biases.PackedScale1() : nullptr; MatMul(activations_.att_out, layer_weights_.att_weights, add, - *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums)); + *activations_.env, RowPtrFromMat(activations_.att_sums)); } public: @@ -556,7 +554,6 @@ class GemmaAttention { layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), - allocator_(ctx.allocator), pool_(ctx.pools.Pool(0)) { HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, @@ -586,7 +583,6 @@ class GemmaAttention { const LayerWeightsPtrs& layer_weights_; const hwy::Divisor& div_seq_len_; const KVCaches& kv_caches_; - const Allocator& allocator_; hwy::ThreadPool& pool_; }; @@ -631,7 +627,7 @@ class VitAttention { HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, - RowPtrFromMat(allocator_, qkv)); + RowPtrFromMat(qkv)); } // TODO(philculliton): transition fully to MatMul. @@ -671,7 +667,7 @@ class VitAttention { }); // this produces C, a (num_tokens_, seq_len) matrix of dot products - MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C)); + MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(C)); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); @@ -737,7 +733,7 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. - auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums); + auto att_sums = RowPtrFromMat(activations_.att_sums); MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias, *activations_.env, att_sums); } @@ -750,7 +746,6 @@ class VitAttention { activations_(activations), layer_weights_(*layer_weights), layer_config_(layer_weights->layer_config), - allocator_(activations.env->ctx.allocator), pool_(activations.env->ctx.pools.Pool(0)) {} HWY_INLINE void operator()() { @@ -769,7 +764,6 @@ class VitAttention { Activations& activations_; const LayerWeightsPtrs& layer_weights_; const LayerConfig& layer_config_; - const Allocator& allocator_; hwy::ThreadPool& pool_; }; @@ -832,10 +826,9 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; // Define slightly more readable names for the weights and activations. - const Allocator& allocator = activations.env->ctx.allocator; - auto hidden_activations = RowPtrFromMat(allocator, activations.C1); - auto multiplier = RowPtrFromMat(allocator, activations.C2); - auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out); + auto hidden_activations = RowPtrFromMat(activations.C1); + auto multiplier = RowPtrFromMat(activations.C2); + auto ffw_out = RowPtrFromMat(activations.ffw_out); using WeightT = typename decltype(layer_weights->gating_einsum_w)::T; @@ -881,22 +874,16 @@ HWY_NOINLINE void FFWVit(Activations& activations, const float* output_bias = add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; - // Define slightly more readable names for the weights and activations. 
- const Allocator& allocator = activations.env->ctx.allocator; - auto hidden_activations = RowPtrFromMat(allocator, activations.C1); - auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out); - // Compute the hidden layer activations. MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1, - *activations.env, hidden_activations); + *activations.env, RowPtrFromMat(activations.C1)); - // Activation (Gelu), store in act. - RowPtrF multiplier = RowPtrF(allocator, nullptr, 0); + // Activation (Gelu), store in C1. ActivationBatched(layer_weights->layer_config.activation, activations.C1); // Hidden layer -> output layer. MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias, - *activations.env, ffw_out); + *activations.env, RowPtrFromMat(activations.ffw_out)); } // `batch_idx` indicates which row of `x` to write to. @@ -932,11 +919,10 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, } const size_t model_dim = weights.weights_config.model_dim; - const size_t vocab_size = weights.weights_config.vocab_size; const float emb_scaling = EmbeddingScaling(model_dim); HWY_DASSERT(token >= 0); - HWY_DASSERT(token < static_cast(vocab_size)); + HWY_DASSERT(token < static_cast(weights.weights_config.vocab_size)); const hn::ScalableTag df; // Using `Stride` to compute the offset works for both NUQ (because we use an @@ -1263,7 +1249,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, // Apply head embedding into image_tokens of size of the LLM kModelDim. MatMul(activations.x, weights.vit_img_head_kernel, weights.vit_img_head_bias.PackedScale1(), *activations.env, - RowPtrFromMat(activations.env->ctx.allocator, image_tokens)); + RowPtrFromMat(image_tokens)); } // Generates one token for each query. `queries_token` is the previous token @@ -1403,7 +1389,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, // Compute logits from last layer activations. MatMul(activations.x, weights.embedder_input_embedding, /*add=*/nullptr, *activations.env, - RowPtrFromMat(activations.env->ctx.allocator, activations.logits)); + RowPtrFromMat(activations.logits)); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index d1e1964..478470f 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -107,7 +107,7 @@ using LayersOutputFunc = std::function; diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 028a6f1..f9707c8 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -25,7 +25,6 @@ namespace gcpp { struct KVCache { - KVCache() = default; // for std::vector. KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size); // Returns a deep copy of the KVCache. diff --git a/io/io.cc b/io/io.cc index e942068..a22a42b 100644 --- a/io/io.cc +++ b/io/io.cc @@ -115,7 +115,7 @@ class FilePosix : public File { #endif return MapPtr(static_cast(mapping), - DeleterFunc2([mapping_size](void* ptr) { + DeleterFunc([mapping_size](void* ptr) { HWY_ASSERT(munmap(ptr, mapping_size) == 0); })); } diff --git a/io/io.h b/io/io.h index 7e1a18c..bb606f9 100644 --- a/io/io.h +++ b/io/io.h @@ -33,7 +33,7 @@ namespace gcpp { // prefer to define Exists inline because there are multiple io*.cc files. struct Path; -using MapPtr = AlignedPtr2; +using MapPtr = AlignedPtr; // Abstract base class enables multiple I/O backends in the same binary. 
class File { diff --git a/io/io_win.cc b/io/io_win.cc index fb6d1fa..1f35a96 100644 --- a/io/io_win.cc +++ b/io/io_win.cc @@ -108,7 +108,7 @@ class FileWin : public File { void* ptr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); if (!ptr) return MapPtr(); return MapPtr(static_cast(ptr), - DeleterFunc2([hMapping](void* ptr) { + DeleterFunc([hMapping](void* ptr) { HWY_ASSERT(UnmapViewOfFile(ptr)); HWY_ASSERT(CloseHandle(hMapping)); })); diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index fa49b2d..c178a26 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -105,7 +105,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { MatStorageT b_trans = GenerateTransposedMat(B_extents, pool); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C = RowPtrFromMat(allocator, c_batch); + const RowPtr C = RowPtrFromMat(c_batch); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 52127a4..d4f8eab 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1140,7 +1140,7 @@ void TestAllDot() { for (size_t variant = 0; variant < kVariants; ++variant) { constexpr size_t kTimeReps = hn::AdjustedReps(10); std::array elapsed; - for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) { + for (size_t time_rep = 0; time_rep < kTimeReps; ++time_rep) { const double start = hwy::platform::Now(); dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index f02edef..062672c 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -864,7 +864,7 @@ class MMPerPackage { : args_(args), pkg_idx_(pkg_idx), // May be overwritten with a view of A, if already BF16. - A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())), + A_(args_.env->storage.A(pkg_idx, A.Extents())), range_np_(range_np), mr_(config.MR()), ranges_mc_(config.RangesOfMC(A.Extents().rows)), @@ -905,9 +905,8 @@ class MMPerPackage { // Compute size of per-worker storage for `kNR` row ranges of B. Stack // allocation avoids passing a worker index. static constexpr size_t B_stride_max_ = - MaxStrideForCyclicOffsets(MMStorage::kMaxKC); - static constexpr size_t B_storage_max_ = - kNR * B_stride_max_ + Allocator::MaxQuantum(); + MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); + static constexpr size_t B_storage_max_ = kNR * B_stride_max_; // Granularity of `ForNP`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. @@ -928,15 +927,14 @@ class MMPerPackage { const size_t K = range_K.Num(); const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K); const size_t B_stride = - StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum()); + Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); // Similar to `loop_nc` below, but here we hoisted `A_view`. 
args_.env->parallel.ForNP( range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, [&](const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K, - B_stride); + const RowPtrBF B_view(B_storage, K, B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -971,8 +969,8 @@ class MMPerPackage { const size_t kc = range_kc.Num(); const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); const RowPtrBF B_view( - args_.env->ctx.allocator, B_storage, kc, - StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum())); + B_storage, kc, + Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -1028,7 +1026,7 @@ class MMPerPackage { const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); const size_t B_stride = - StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum()); + Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. @@ -1037,8 +1035,7 @@ class MMPerPackage { [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K, - B_stride); + const RowPtrBF B_view(B_storage, K, B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -1064,8 +1061,8 @@ class MMPerPackage { zone.MaybeEnter("MM.NT_MT_K", args_); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); - const size_t B_stride = StrideForCyclicOffsets( - kc_max, args_.env->ctx.allocator.Quantum()); + const size_t B_stride = + Stride(MatPadding::kOdd, kc_max, sizeof(BF16), line_bytes_); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. @@ -1091,8 +1088,7 @@ class MMPerPackage { ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max, - B_stride); + const RowPtrBF B_view(B_storage, kc_max, B_stride); // Peel off the first iteration of the kc loop: avoid // zero-initializing `partial` by writing into it. @@ -1172,13 +1168,12 @@ class MMPerPackage { // Autotuning wrapper for `DoDecompressA`. template HWY_INLINE RowPtrBF DecompressA(const MatPtrT& A) const { - const Allocator& allocator = args_.env->ctx.allocator; MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; // If already BF16, maybe return a view: if constexpr (hwy::IsSame()) { // Only if no zero-padding required. const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A); + if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(A); } if (HWY_LIKELY(autotune.Best())) { diff --git a/ops/matmul.h b/ops/matmul.h index 8e50c60..773322f 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -217,8 +217,7 @@ class MMStorage { : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), MatPadding::kOdd), // Same stride independent of the actual C.Cols() so we can pre-bind. 
- partial_(allocator, partial_storage_.Row(0), kMaxN, - partial_storage_.Stride()) { + partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) { // Per-package allocation so each can decompress A into its own copy. parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { pkg_A_[pkg_idx].reset(new MatStorageT( @@ -240,12 +239,11 @@ class MMStorage { } // Returns per-package matrix view. - RowPtrBF A(const Allocator& allocator, size_t pkg_idx, - const Extents2D& extents) const { + RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); - return RowPtrBF(allocator, const_cast(pkg_A_[pkg_idx]->Row(0)), - extents.cols, pkg_A_[pkg_idx]->Stride()); + return RowPtrBF(const_cast(pkg_A_[pkg_idx]->Row(0)), extents.cols, + pkg_A_[pkg_idx]->Stride()); } RowPtrD Partial() const { return partial_; } diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 668a983..5d3f4f1 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -205,7 +205,6 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents, template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulEnv& env, int line) { - const Allocator& allocator = env.ctx.allocator; hwy::ThreadPool& pool = env.ctx.pools.Pool(); fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName(), @@ -229,8 +228,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, add_storage.SetScale(1.0f); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C_slow = RowPtrFromMat(allocator, c_slow_batch); - const RowPtr C = RowPtrFromMat(allocator, c_batch); + const RowPtr C_slow = RowPtrFromMat(c_slow_batch); + const RowPtr C = RowPtrFromMat(c_batch); MatMulSlow(a, b_trans, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. diff --git a/ops/ops_test.cc b/ops/ops_test.cc index cf88fd8..780eb61 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -507,7 +507,7 @@ void TestLayerNormSimple() { const size_t kSize = 52; std::vector values(kSize); // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995 - for (int i = 0; i < kSize; ++i) { + for (size_t i = 0; i < kSize; ++i) { values[i] = (i % 2 == 0) ? 1.0f : -1.0f; } std::vector scale(kSize, 1.2f); diff --git a/util/allocator.cc b/util/allocator.cc index d6c1506..df2575e 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -132,7 +132,11 @@ size_t DetectTotalMiB(size_t page_bytes) { Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { line_bytes_ = DetectLineBytes(); + // Ensure MaxLineBytes() is an upper bound. + HWY_ASSERT(MaxLineBytes() >= LineBytes()); + vector_bytes_ = hwy::VectorBytes(); + step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); base_page_bytes_ = DetectPageSize(); quantum_bytes_ = step_bytes_; // may overwrite below @@ -165,8 +169,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { // Ensure pages meet the alignment requirements of `AllocBytes`. HWY_ASSERT(base_page_bytes_ >= quantum_bytes_); quantum_bytes_ = base_page_bytes_; - // Ensure MaxQuantum() is an upper bound. 
- HWY_ASSERT(MaxQuantum() >= Quantum()); should_bind_ = true; } else { HWY_WARN( @@ -175,9 +177,6 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { } } } - - HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0); - quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1; } size_t Allocator::FreeMiB() const { @@ -201,7 +200,7 @@ size_t Allocator::FreeMiB() const { #endif } -AlignedPtr2 Allocator::AllocBytes(size_t bytes) const { +AlignedPtr Allocator::AllocBytes(size_t bytes) const { // If we are not binding, the Highway allocator is cheaper than `mmap`, and // defends against 2K aliasing. if (!should_bind_) { @@ -217,10 +216,9 @@ AlignedPtr2 Allocator::AllocBytes(size_t bytes) const { // alignment scheme in aligned_allocator.cc and does not work for // already-aligned pointers as returned by `mmap`, hence we wrap the Highway // pointer in our own deleter. - return AlignedPtr2(p.release(), DeleterFunc2([](void* ptr) { - hwy::FreeAlignedBytes(ptr, nullptr, - nullptr); - })); + return AlignedPtr(p.release(), DeleterFunc([](void* ptr) { + hwy::FreeAlignedBytes(ptr, nullptr, nullptr); + })); } // Binding, or large vector/cache line size: use platform-specific allocator. @@ -234,17 +232,16 @@ AlignedPtr2 Allocator::AllocBytes(size_t bytes) const { const int fd = -1; void* p = mmap(0, bytes, prot, flags, fd, off_t{0}); if (p == MAP_FAILED) p = nullptr; - return AlignedPtr2(static_cast(p), - DeleterFunc2([bytes](void* ptr) { - HWY_ASSERT(munmap(ptr, bytes) == 0); - })); + return AlignedPtr( + static_cast(p), + DeleterFunc([bytes](void* ptr) { HWY_ASSERT(munmap(ptr, bytes) == 0); })); #elif HWY_OS_WIN const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_); - return AlignedPtr2( + return AlignedPtr( static_cast(_aligned_malloc(bytes, alignment)), - DeleterFunc2([](void* ptr) { _aligned_free(ptr); })); + DeleterFunc([](void* ptr) { _aligned_free(ptr); })); #else - return AlignedPtr2(nullptr, DeleterFunc2()); + return AlignedPtr(nullptr, DeleterFunc()); #endif } diff --git a/util/allocator.h b/util/allocator.h index 9cf9b60..a996d6f 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -34,13 +34,13 @@ namespace gcpp { // Custom deleter for types without a dtor, but where the deallocation requires // state, e.g. a lambda with *by-value* capture. -class DeleterFunc2 { +class DeleterFunc { public: // `MatOwnerT` requires this to be default-constructible. - DeleterFunc2() = default; + DeleterFunc() = default; template - DeleterFunc2(const Closure& free_closure) : free_func_(free_closure) {} + DeleterFunc(const Closure& free_closure) : free_func_(free_closure) {} template void operator()(T* p) const { @@ -52,10 +52,10 @@ class DeleterFunc2 { }; // Wrapper that also calls the destructor for each element being deallocated. -class DeleterDtor2 { +class DeleterDtor { public: - DeleterDtor2() {} - DeleterDtor2(size_t num, DeleterFunc2 free) : num_(num), free_(free) {} + DeleterDtor() {} + DeleterDtor(size_t num, DeleterFunc free) : num_(num), free_(free) {} template void operator()(T* p) const { @@ -67,15 +67,15 @@ class DeleterDtor2 { private: size_t num_; - DeleterFunc2 free_; + DeleterFunc free_; }; // Unique (move-only) pointer to aligned POD T, which can be an array or class. template -using AlignedPtr2 = std::unique_ptr; +using AlignedPtr = std::unique_ptr; // Unique (move-only) pointer to an aligned array of non-POD T. 
template -using AlignedClassPtr2 = std::unique_ptr; +using AlignedClassPtr = std::unique_ptr; // Both allocation, binding, and row accessors depend on the sizes of memory // pages and cache lines. To avoid having to pass `Allocator&` everywhere, we @@ -90,26 +90,24 @@ class Allocator { // Bytes per cache line, or a reasonable guess if unknown. Used to choose // ranges such that there will be no false sharing. size_t LineBytes() const { return line_bytes_; } + // Upper bound on `LineBytes()`, for stack allocations. + static constexpr size_t MaxLineBytes() { return 256; } // Bytes per full vector. Used to compute loop steps. size_t VectorBytes() const { return vector_bytes_; } // Work granularity that avoids false sharing and partial vectors. // = HWY_MAX(LineBytes(), VectorBytes()) size_t StepBytes() const { return step_bytes_; } + // File size multiple required for memory mapping. size_t BasePageBytes() const { return base_page_bytes_; } + // Either StepBytes or BasePageBytes if NUMA. size_t QuantumBytes() const { return quantum_bytes_; } template + // For rounding down elements to the page size in `BindB/BindC`. size_t Quantum() const { return QuantumBytes() / sizeof(T); } - // Upper bound on `Quantum()`, for stack allocations. - template - static constexpr size_t MaxQuantum() { - return 4096 / sizeof(T); - } - // = QuantumBytes() / StepBytes() - 1 - size_t QuantumStepMask() const { return quantum_step_mask_; } // L1 and L2 are typically per core. size_t L1Bytes() const { return l1_bytes_; } @@ -123,35 +121,35 @@ class Allocator { // Returns byte pointer aligned to `QuantumBytes()`, without calling // constructors nor destructors on deletion. Type-erased so this can be // implemented in `allocator.cc` and called by `MatOwner`. - AlignedPtr2 AllocBytes(size_t bytes) const; + AlignedPtr AllocBytes(size_t bytes) const; // Returns pointer aligned to `QuantumBytes()`, without calling constructors // nor destructors on deletion. template - AlignedPtr2 Alloc(size_t num) const { + AlignedPtr Alloc(size_t num) const { const size_t bytes = num * sizeof(T); // Fail if the `bytes = num * sizeof(T)` computation overflowed. HWY_ASSERT(bytes / sizeof(T) == num); - AlignedPtr2 p8 = AllocBytes(bytes); - return AlignedPtr2(HWY_RCAST_ALIGNED(T*, p8.release()), - p8.get_deleter()); + AlignedPtr p8 = AllocBytes(bytes); + return AlignedPtr(HWY_RCAST_ALIGNED(T*, p8.release()), + p8.get_deleter()); } // Same as Alloc, but calls constructor(s) with `args` and the deleter will // call destructor(s). template - AlignedClassPtr2 AllocClasses(size_t num, Args&&... args) const { + AlignedClassPtr AllocClasses(size_t num, Args&&... args) const { const size_t bytes = num * sizeof(T); // Fail if the `bytes = num * sizeof(T)` computation overflowed. HWY_ASSERT(bytes / sizeof(T) == num); - AlignedPtr2 p8 = AllocBytes(bytes); + AlignedPtr p8 = AllocBytes(bytes); T* p = HWY_RCAST_ALIGNED(T*, p8.release()); for (size_t i = 0; i < num; ++i) { new (p + i) T(std::forward(args)...); } - return AlignedClassPtr2(p, DeleterDtor2(num, p8.get_deleter())); + return AlignedClassPtr(p, DeleterDtor(num, p8.get_deleter())); } // Returns whether `BindMemory` can/should be called, i.e. 
we have page-level @@ -170,7 +168,6 @@ class Allocator { size_t step_bytes_; size_t base_page_bytes_; size_t quantum_bytes_; - size_t quantum_step_mask_; size_t l1_bytes_ = 0; size_t l2_bytes_ = 0; diff --git a/util/mat.cc b/util/mat.cc index 5950596..d40088f 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -89,38 +89,31 @@ void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) { } } -// Returns `num` rounded up to an odd number of cache lines. This would also -// prevent 4K aliasing and is coprime with the cache associativity, which -// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`. -static size_t RoundUpToOddLines(size_t num, size_t line_bytes, - size_t element_bytes) { - HWY_DASSERT(line_bytes >= 32); - HWY_DASSERT(line_bytes % element_bytes == 0); - const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes); - const size_t padded_num = (lines | 1) * line_bytes / element_bytes; - HWY_DASSERT(padded_num >= num); - return padded_num; -} - -static size_t Stride(const Allocator& allocator, const MatPtr& mat, - MatPadding padding) { +size_t Stride(MatPadding padding, size_t cols, size_t element_bytes, + size_t line_bytes) { switch (padding) { case MatPadding::kPacked: default: - return mat.Cols(); - case MatPadding::kOdd: - return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(), - mat.ElementBytes()); - case MatPadding::kCyclic: - return StrideForCyclicOffsets( - mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes()); + return cols; + case MatPadding::kOdd: { + // Round up to an odd number of cache lines to prevent 4K aliasing and + // reduce conflict misses (coprime with the cache associativity). + HWY_DASSERT(line_bytes >= 32); + HWY_DASSERT(line_bytes % element_bytes == 0); + const size_t lines = hwy::DivCeil(cols * element_bytes, line_bytes); + const size_t padded_cols = (lines | 1) * line_bytes / element_bytes; + HWY_DASSERT(padded_cols >= cols); + return padded_cols; + } } } -void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) { +void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { const bool is_nuq = mat.GetType() == Type::kNUQ; + if (is_nuq) padding = MatPadding::kPacked; const Allocator& allocator = ThreadingContext::Get().allocator; - const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding); + const size_t stride = + Stride(padding, mat.Cols(), mat.ElementBytes(), allocator.LineBytes()); const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride; // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` // might not be enough, hence add extra. `MatT` is at least one byte, which diff --git a/util/mat.h b/util/mat.h index 7d9113a..dd822f4 100644 --- a/util/mat.h +++ b/util/mat.h @@ -28,7 +28,7 @@ #include "compression/shared.h" // Type #include "gemma/tensor_info.h" #include "io/fields.h" -#include "util/allocator.h" // AlignedPtr2 +#include "util/allocator.h" // AlignedPtr #include "util/basics.h" // Extents2D // IWYU pragma: end_exports #include "hwy/base.h" @@ -339,24 +339,6 @@ void ZeroInit(MatPtr& mat); // F32/F64 only. void RandInit(MatPtr& mat, float stddev, std::mt19937& gen); -// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If -// `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB. -// To avoid remote accesses, we would thus pad each row to that, which results -// in 4K aliasing and/or cache conflict misses. 
`RowPtr` is able to prevent that -// by pulling rows forward by a cyclic offset, which is still a multiple of the -// cache line size. This requires an additional `Allocator::QuantumBytes()` of -// padding after also rounding up to that, which considerably increases size for -// tall and skinny tensors. -static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) { - return hwy::RoundUpTo(cols, quantum) + quantum; -} -// Constexpr version (upper bound) for allocating storage in MatMul. -template -constexpr size_t MaxStrideForCyclicOffsets(size_t cols) { - constexpr size_t quantum = Allocator::MaxQuantum(); - return hwy::RoundUpTo(cols, quantum) + quantum; -} - // Our tensors are always row-major. This enum indicates how much (if any) // padding comes after each row. enum class MatPadding { @@ -373,11 +355,14 @@ enum class MatPadding { // Enough to round up to an odd number of cache lines, which can reduce // cache conflict misses or 4K aliasing. kOdd, - // Enough to enable the "cyclic offsets" optimization for `MatMul`. - kCyclic, }; -// Type-erased, allows storing `AlignedPtr2` for various T in the same +// The stride (offset in elements between rows) that `MatOwner/MatStorageT` +// will use. +size_t Stride(MatPadding padding, size_t cols, size_t element_bytes, + size_t line_bytes); + +// Type-erased, allows storing `AlignedPtr` for various T in the same // vector. class MatOwner { public: @@ -390,7 +375,7 @@ class MatOwner { void AllocateFor(MatPtr& mat, MatPadding padding); private: - AlignedPtr2 storage_; + AlignedPtr storage_; }; // Multiple `MatOwner`, with support for parallel allocation. @@ -443,84 +428,40 @@ MatStorageT MakePacked(const char* name, size_t rows, size_t cols) { } // Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with -// seekable (non-NUQ) T. This has less metadata, but support for cyclic offsets. +// seekable (non-NUQ) T. #pragma pack(push, 1) // power of two size template class RowPtr { public: - RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols, - size_t stride) + RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) : row0_(row0), - stride_(stride), - // TODO: disabled because otherwise we see non-deterministic results. - row_mask_(0), - // static_cast(allocator.QuantumStepMask() & 0xFFFFFFFFu)), cols_(static_cast(cols)), - step_bytes_(static_cast(allocator.StepBytes())), - quantum_bytes_(allocator.QuantumBytes()) { + stride_(static_cast(stride)) { HWY_DASSERT(stride >= cols); - HWY_DASSERT(row_mask_ != ~uint32_t{0}); - if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) { - row_mask_ = 0; - if constexpr (HWY_IS_DEBUG_BUILD) { - static bool once; - if (stride != cols && !once) { - once = true; - HWY_WARN( - "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), " - "T=%zu; this forces us to disable cyclic offsets.", - stride, cols, sizeof(T)); - } - } - } } - RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols) - : RowPtr(allocator, row0, cols, cols) {} + RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {} - T* HWY_RESTRICT Row(size_t r) const { - // How much of the previous row's padding to consume. 
- const size_t pad_bytes = (r & row_mask_) * step_bytes_; - HWY_DASSERT(pad_bytes < static_cast(quantum_bytes_)); - return row0_ + stride_ * r - pad_bytes; - } + T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; } size_t Cols() const { return static_cast(cols_); } - size_t Stride() const { return stride_; } + size_t Stride() const { return static_cast(stride_); } void SetStride(size_t stride) { HWY_DASSERT(stride >= Cols()); stride_ = stride; - // The caller might not have padded enough, so disable the padding in Row(). - // Rows will now be exactly `stride` elements apart. This is used when - // writing to the KV cache via MatMul. - row_mask_ = 0; } // Returns 2D subrange whose top-left is `r, c` and width is `cols`. RowPtr View(size_t r, size_t c, size_t cols) const { HWY_DASSERT(c < Cols()); HWY_DASSERT(cols <= Cols() - c); - return RowPtr(Row(r) + c, cols, stride_, row_mask_, step_bytes_, - quantum_bytes_); + return RowPtr(Row(r) + c, cols, stride_); } private: - // For `View()`. - RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask, - uint32_t step_bytes, uint32_t quantum_bytes) - : row0_(new_row0), - stride_(stride), - row_mask_(row_mask), - cols_(new_cols), - step_bytes_(step_bytes), - quantum_bytes_(quantum_bytes) {} - T* HWY_RESTRICT row0_; - size_t stride_; - uint32_t row_mask_; uint32_t cols_; - uint32_t step_bytes_; - uint32_t quantum_bytes_; + uint32_t stride_; }; #pragma pack(pop) @@ -528,14 +469,12 @@ using RowPtrBF = RowPtr; using RowPtrF = RowPtr; using RowPtrD = RowPtr; -// TODO: remove allocator arg once kCyclic is removed. template -RowPtr RowPtrFromMat(const Allocator& allocator, - const MatPtrT& row_vectors) { +RowPtr RowPtrFromMat(const MatPtrT& row_vectors) { // RowPtr is non-const for MatMul C, but is also used for A which is const. // Callers are responsible for checking their usage of RowPtr. - return RowPtr(allocator, const_cast(row_vectors.Row(0)), - row_vectors.Cols(), row_vectors.Stride()); + return RowPtr(const_cast(row_vectors.Row(0)), row_vectors.Cols(), + row_vectors.Stride()); } } // namespace gcpp diff --git a/util/threading.h b/util/threading.h index 205226e..5e13dae 100644 --- a/util/threading.h +++ b/util/threading.h @@ -38,7 +38,7 @@ namespace gcpp { // Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows // moving because it is a typedef to `std::unique_ptr`. -using PoolPtr = AlignedClassPtr2; +using PoolPtr = AlignedClassPtr; // Creates a hierarchy of thread pools according to `BoundedTopology`: one with // a thread per enabled package; for each of those, one with a thread per
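
Notes on the patterns this patch touches. The sketches below are standalone C++ illustrations, not gemma.cpp code; every name they introduce (MapReadOnly, DtorThenFree, OddStride, RowView, and the kMaxKC/kNR placeholder values) is hypothetical.

The DeleterFunc/AlignedPtr pair (previously carrying the "2" suffix) lets a unique_ptr own memory whose release needs state, which is what io.cc and io_win.cc rely on to unmap a file mapping whose size, or Windows mapping handle, is only known at map time. A rough POSIX-only sketch of the same idea, using std::function directly in place of DeleterFunc:

  #include <fcntl.h>
  #include <sys/mman.h>
  #include <sys/stat.h>
  #include <unistd.h>

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <functional>
  #include <memory>

  // Stand-in for AlignedPtr<uint8_t[]>: the deleter is type-erased and can
  // capture state, here the mapping size that munmap needs.
  using MapPtr = std::unique_ptr<uint8_t[], std::function<void(uint8_t*)>>;

  MapPtr MapReadOnly(const char* path) {
    const int fd = open(path, O_RDONLY);
    if (fd < 0) return MapPtr();
    struct stat st;
    if (fstat(fd, &st) != 0 || st.st_size <= 0) {
      close(fd);
      return MapPtr();
    }
    const size_t bytes = static_cast<size_t>(st.st_size);
    void* p = mmap(nullptr, bytes, PROT_READ, MAP_SHARED, fd, 0);
    close(fd);  // the mapping outlives the descriptor
    if (p == MAP_FAILED) return MapPtr();
    return MapPtr(static_cast<uint8_t*>(p),
                  [bytes](uint8_t* ptr) { munmap(ptr, bytes); });
  }

  int main() {
    const MapPtr map = MapReadOnly("/etc/hostname");
    if (map) printf("first byte: %c\n", map[0]);
    return 0;
  }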
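Allocator::AllocClasses pairs placement-new construction with a DeleterDtor that remembers the element count, keeping destruction separate from deallocation. A simplified sketch of that ownership scheme, built on plain operator new/delete rather than the NUMA-aware AllocBytes:

  #include <cstddef>
  #include <memory>
  #include <new>
  #include <string>

  // Counted deleter: run each element's destructor, then release the bytes.
  template <typename T>
  struct DtorThenFree {
    size_t num;
    void operator()(T* p) const {
      for (size_t i = 0; i < num; ++i) p[i].~T();
      ::operator delete(p);
    }
  };

  template <typename T, class... Args>
  std::unique_ptr<T[], DtorThenFree<T>> AllocClasses(size_t num,
                                                     const Args&... args) {
    T* p = static_cast<T*>(::operator new(num * sizeof(T)));
    for (size_t i = 0; i < num; ++i) new (p + i) T(args...);
    return std::unique_ptr<T[], DtorThenFree<T>>(p, DtorThenFree<T>{num});
  }

  int main() {
    auto strings = AllocClasses<std::string>(3, "hello");
    return strings[0].size() == 5 ? 0 : 1;
  }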
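With kCyclic gone, row padding is decided by the free function Stride() in util/mat.cc. For MatPadding::kOdd it rounds each row up to an odd number of cache lines, which avoids 4K aliasing and keeps row starts coprime with the cache associativity, without the extra QuantumBytes() of padding the cyclic scheme required. A worked example of that arithmetic (the 4096-column, 64-byte-line constants are illustrative):

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>

  // Mirrors the MatPadding::kOdd branch of Stride() in util/mat.cc.
  size_t OddStride(size_t cols, size_t element_bytes, size_t line_bytes) {
    const size_t lines = (cols * element_bytes + line_bytes - 1) / line_bytes;
    return (lines | 1) * line_bytes / element_bytes;  // odd number of lines
  }

  int main() {
    // 4096 BF16 (2-byte) columns with 64-byte lines: 128 lines round up to
    // 129, i.e. a stride of 4128 elements (8256 bytes) per row.
    printf("%zu\n", OddStride(4096, sizeof(uint16_t), 64));
    return 0;
  }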
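The per-worker B buffer in matmul-inl.h is now sized from Allocator::MaxLineBytes() rather than MaxQuantum(): rounding K * sizeof(BF16) up to whole lines adds at most one line, and forcing the line count odd adds at most one more, so the kOdd stride of any K <= MMStorage::kMaxKC fits in kMaxKC + 2 * MaxLineBytes() / sizeof(BF16) elements. A compile-time check of that bound under assumed values (MMStorage::kMaxKC and kNR are defined elsewhere and only guessed at here):

  #include <cstddef>
  #include <cstdint>

  using BF16 = uint16_t;  // stand-in for hwy::bfloat16_t

  // Guessed values, for illustration only.
  constexpr size_t kMaxKC = 8192;        // MMStorage::kMaxKC
  constexpr size_t kMaxLineBytes = 256;  // Allocator::MaxLineBytes()
  constexpr size_t kNR = 4;

  constexpr size_t OddStride(size_t cols, size_t e, size_t line) {
    return (((cols * e + line - 1) / line) | 1) * line / e;
  }

  constexpr size_t kBStrideMax = kMaxKC + 2 * kMaxLineBytes / sizeof(BF16);
  constexpr size_t kBStorageMax = kNR * kBStrideMax;
  static_assert(OddStride(kMaxKC, sizeof(BF16), kMaxLineBytes) <= kBStrideMax,
                "kOdd stride fits in the stack bound at K = kMaxKC");

  int main() { return kBStorageMax > 0 ? 0 : 1; }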
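RowPtr itself loses the cyclic-offset machinery: Row(r) is now plain row0 + stride * r, RowPtrFromMat no longer takes an Allocator, and SetStride() still lets GemmaAttention write K/V rows straight into the interleaved cache by spacing them cache_pos_size_ elements apart. A minimal analog of the simplified accessor:

  #include <cstddef>
  #include <cstdio>
  #include <vector>

  // Minimal analog of the simplified RowPtr in util/mat.h.
  template <typename T>
  class RowView {
   public:
    RowView(T* row0, size_t cols, size_t stride)
        : row0_(row0), cols_(cols), stride_(stride) {}

    T* Row(size_t r) const { return row0_ + stride_ * r; }
    size_t Cols() const { return cols_; }
    size_t Stride() const { return stride_; }
    // Rows may be spaced further apart than the allocation's own stride,
    // e.g. when writing interleaved KV-cache rows.
    void SetStride(size_t stride) { stride_ = stride; }

   private:
    T* row0_;
    size_t cols_;
    size_t stride_;
  };

  int main() {
    std::vector<float> storage(4 * 4128);           // 4 rows, stride 4128
    RowView<float> view(storage.data(), 4096, 4128);
    view.Row(3)[5] = 1.0f;                          // element (3, 5)
    printf("%.1f\n", view.Row(3)[5]);
    return 0;
  }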