From e890d46f3027a992ca62c851f14e811096f07c8d Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 16 May 2025 07:41:36 -0700 Subject: [PATCH] 1.31x batch prefill, 1.24x batch decode speedup: NUMA binding Only the weights; binding MatMul output worsens batch=1 prefill. Update gemma_batch_bench to use --decode_qbatch. Fix/remove prefill_activations in gemma-inl.h. Refactor: use BasePageBytes directly when binding Move BindB/C to .cc by de-templatizing Remove MatOwners::AllocateFor because it is weights-specific (binding or not) Disband MatOwners, replace with vector PiperOrigin-RevId: 759610477 --- BUILD.bazel | 3 +- backprop/test_util.h | 3 +- compression/python/compression_clif_aux.cc | 5 +- evals/gemma_batch_bench.cc | 28 +++++----- gemma/activations.h | 2 + gemma/gemma-inl.h | 29 +++++----- gemma/weights.cc | 39 ++++++++++---- gemma/weights.h | 42 +++++++++------ ops/bench_matmul.cc | 11 ++-- ops/matmul.cc | 61 +++++++++++++++++++++- ops/matmul.h | 59 ++++----------------- util/allocator.h | 10 ++-- util/mat.cc | 13 ----- util/mat.h | 23 +------- 14 files changed, 169 insertions(+), 159 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 428a00e..b673341 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -172,7 +172,6 @@ cc_library( "//io:fields", "@highway//:hwy", "@highway//:profiler", - "@highway//:thread_pool", ], ) @@ -217,6 +216,7 @@ cc_library( ":configs", ":mat", ":model_store", + ":ops", ":tensor_info", ":threading_context", "//compression:compress", @@ -281,7 +281,6 @@ cc_library( ":allocator", ":basics", ":mat", - ":threading", ":threading_context", "//compression:compress", "@highway//:algo", diff --git a/backprop/test_util.h b/backprop/test_util.h index 10f0386..8f32cbf 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -20,6 +20,7 @@ #include #include +#include #include "gtest/gtest.h" #include "gemma/configs.h" @@ -75,7 +76,7 @@ class WeightsWrapper { ModelWeightsPtrs& get() { return weights_; } private: - MatOwners owners_; + std::vector owners_; ModelWeightsPtrs weights_; }; diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index f8ca0ca..5051a87 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -67,7 +67,8 @@ class SbsWriterImpl : public ISbsWriter { } mat.AppendTo(serialized_mat_ptrs_); - mat_owners_.AllocateFor(mat, MatPadding::kPacked); + mat_owners_.push_back(MatOwner()); + mat_owners_.back().AllocateFor(mat, MatPadding::kPacked); // Handle gemma_export_test's MockArray. Write blobs so that the test // succeeds, but we only have 10 floats, not the full tensor. @@ -121,7 +122,7 @@ class SbsWriterImpl : public ISbsWriter { } hwy::ThreadPool& pool_; - MatOwners mat_owners_; + std::vector mat_owners_; CompressWorkingSet working_set_; BlobWriter writer_; std::vector serialized_mat_ptrs_; diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 2976e1e..b691706 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -22,13 +22,13 @@ #include "gemma/configs.h" #include "gemma/gemma.h" #include "hwy/base.h" +#include "hwy/profiler.h" #include "hwy/tests/hwy_gtest.h" // This test can be run manually with the downloaded gemma weights. 
// To run the test, pass the following flags: -// --model --tokenizer --weights +// --tokenizer --weights // It should pass for the following models: -// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it, // Gemma2: gemma2-2b-it, 9b-it, 27b-it, namespace gcpp { @@ -76,26 +76,25 @@ class GemmaTest : public ::testing::Test { return replies; } - void GenerateTokens(std::vector &kQA, size_t num_questions) { + void GenerateTokens(const std::vector& questions) { ASSERT_NE(s_env->GetGemma(), nullptr); + // Fills prompts round robin from `questions` until the desired batch size. std::vector inputs; - inputs.reserve(num_questions); - for (size_t i = 0; i < num_questions; ++i) { - inputs.push_back(kQA[i]); + inputs.reserve(s_env->MutableConfig().decode_qbatch_size); + size_t qpos = 0; + for (size_t i = 0; i < inputs.capacity(); ++i) { + inputs.push_back(questions[qpos++]); + if (qpos == questions.size()) qpos = 0; } std::vector responses = BatchGemmaReply(inputs); - for (size_t i = 0; i < num_questions; ++i) { - std::string response = responses.at(i); - fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str()); + for (size_t i = 0; i < inputs.size(); ++i) { + fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str()); } } }; TEST_F(GemmaTest, RandomQuestionsBatched) { - s_env->MutableConfig().decode_qbatch_size = 3; - s_env->MutableConfig().verbosity = 5; - static std::vector kQA = { {"Write me a poem about Australia?"}, {"What's the history of Denmark?"}, @@ -130,8 +129,9 @@ TEST_F(GemmaTest, RandomQuestionsBatched) { {"Tell me about space travel."}, {"Explain to me how electric cars work."}, }; - static const size_t kNum = kQA.size(); - GenerateTokens(kQA, kNum); + s_env->MutableConfig().verbosity = 5; + GenerateTokens(kQA); + PROFILER_PRINT_RESULTS(); } } // namespace } // namespace gcpp diff --git a/gemma/activations.h b/gemma/activations.h index 6f615cf..5f19dd8 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -83,6 +83,8 @@ struct Activations { env(env) { HWY_ASSERT(batch_size != 0); + + // Note that BindC on any MatMul output considerably slows down Prefill. } void SetBatchSize(size_t batch_size) { diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index cd10dfb..255d859 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -21,7 +21,7 @@ #include #include // std::min -#include +#include // std::make_unique #include #include "gemma/activations.h" @@ -1055,7 +1055,8 @@ HWY_NOINLINE void Prefill( // intensity, and so we are eventually compute-limited. We could devote some // threads to parallelizing over queries, but for simplicity we assign them // all to MatMul. - const size_t max_tbatch_size = activations.x.Rows(); + const size_t max_tbatch_size = runtime_config.prefill_tbatch_size; + HWY_DASSERT(max_tbatch_size <= activations.x.Rows()); // For each query. `qi` is within the batch, not the global query index. for (size_t qi = 0; qi < num_queries; ++qi) { @@ -1429,18 +1430,10 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, // Prefill stops before min_prompt_size - 1 because the last prompt // token is the first input token for generation. timing_info.prefill_start = hwy::platform::Now(); - // If tbatch is larger than the qbatch we already have in `activations`, then - // allocate prefill_activations, otherwise reuse. - const bool use_prefill_activations = - runtime_config.prefill_tbatch_size > activations.x.Rows(); - Activations prefill_activations( - weights.weights_config, - use_prefill_activations ? 
runtime_config.prefill_tbatch_size : 0, - activations.env); + // Note that Prefill calls activations.SetBatchSize, so we reset it below. Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, - query_idx_start, weights, - use_prefill_activations ? prefill_activations : activations, - runtime_config, div_seq_len, kv_caches); + query_idx_start, weights, activations, runtime_config, div_seq_len, + kv_caches); // Compute the number of tokens that were prefilled and notify timing_info. size_t prefilled_tokens = 0; for (size_t qi = 0; qi < num_queries; ++qi) { @@ -1448,6 +1441,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, } timing_info.NotifyPrefill(prefilled_tokens); // queries_pos are incremented by Prefill. + activations.SetBatchSize(num_queries); // Storage for the last generated token from each query, passed to the next // Transformer() call. @@ -1489,8 +1483,10 @@ void GenerateSingleT(const ModelStore& model, constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; + const size_t max_batch_size = + HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); // TODO: move into Gemma? - Activations activations(model.Config(), kNumQueries, env); + Activations activations(model.Config(), max_batch_size, env); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries); @@ -1523,7 +1519,9 @@ void GenerateBatchT(const ModelStore& model, } } - Activations activations(model.Config(), max_qbatch_size, env); + const size_t max_batch_size = + HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); + Activations activations(model.Config(), max_batch_size, env); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { @@ -1557,6 +1555,7 @@ void GenerateImageTokensT(const ModelStore& model, prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); Activations prefill_activations(vit_config, vit_config.seq_len, env); + prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(weights, prefill_runtime_config, image, image_tokens, prefill_activations); diff --git a/gemma/weights.cc b/gemma/weights.cc index eb65f48..0269c60 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -16,10 +16,10 @@ #include "gemma/weights.h" #include +#include #include +#include -#include -#include #include #include #include @@ -30,6 +30,7 @@ #include "gemma/configs.h" #include "gemma/model_store.h" #include "io/blob_store.h" +#include "ops/matmul.h" // MMParallel #include "util/mat.h" #include "util/threading_context.h" #include "hwy/base.h" @@ -46,7 +47,7 @@ namespace gcpp { static void InitAttWeightsNUQ(const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, MatPtrT& att_weights, - MatOwners& mat_owners) { + std::vector& mat_owners) { if (!attn_vec_einsum_w.HasPtr()) return; HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); @@ -91,11 +92,29 @@ static void SplitW1NUQ(const LayerConfig& layer_config) { } template <> -void LayerWeightsPtrs::Fixup(MatOwners& mat_owners) { +void LayerWeightsPtrs::Fixup(std::vector& mat_owners) { InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners); SplitW1NUQ(layer_config); } +// Allocates multiple in parallel and binds to NUMA nodes. 
+static void AllocateAndBindAll(const std::vector& mats, + MatPadding padding, + std::vector& owners, + hwy::ThreadPool& pool) { + const size_t start = owners.size(); + owners.resize(start + mats.size()); + + MMParallel parallel(ThreadingContext::Get()); + + // Allocate in parallel because faulting in large tensors is slow. + pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) { + owners[start + task].AllocateFor(*mats[task], padding); + // TODO(janwas): MatMul outputs will later also be BF16. + BindB(*mats[task], sizeof(float), parallel); + }); +} + // Parallel I/O into allocated memory, or mapped view of file. The latter is // better when the file is huge, but page faults add noise to measurements. enum class Mode { kRead, kMap }; @@ -209,10 +228,10 @@ static void ReadBatches(const BlobReader& reader, } // Aborts on error. -static void MapOrRead(const std::vector& mats, BlobReader& reader, - const std::vector& ranges, Tristate map, - MatOwners& mat_owners, const MatPadding padding, - hwy::ThreadPool& pool) { +static void MapOrReadAll(const std::vector& mats, BlobReader& reader, + const std::vector& ranges, Tristate map, + std::vector& mat_owners, + const MatPadding padding, hwy::ThreadPool& pool) { HWY_ASSERT(mats.size() == ranges.size()); if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) { @@ -226,7 +245,7 @@ static void MapOrRead(const std::vector& mats, BlobReader& reader, { PROFILER_ZONE("Startup.Weights.Allocate"); // NOTE: this changes the stride of `mats`! - mat_owners.AllocateFor(mats, padding, pool); + AllocateAndBindAll(mats, padding, mat_owners, pool); } const std::vector batches = @@ -259,7 +278,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader, }); }); - MapOrRead(mats, reader, ranges, map, mat_owners_, padding, pool); + MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool); Fixup(pool); } diff --git a/gemma/weights.h b/gemma/weights.h index af94c58..aa184db 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -27,11 +27,12 @@ #include #include "compression/types.h" // IsF32 -#include "gemma/configs.h" // ModelConfig -#include "gemma/model_store.h" // ModelStore -#include "gemma/tensor_info.h" // TensorInfoRegistry -#include "io/blob_store.h" // BlobWriter -#include "util/mat.h" // MatPtr +#include "gemma/configs.h" // ModelConfig +#include "gemma/model_store.h" // ModelStore +#include "gemma/tensor_info.h" // TensorInfoRegistry +#include "io/blob_store.h" // BlobWriter +#include "ops/matmul.h" // MatMulEnv +#include "util/mat.h" // MatPtr #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -191,7 +192,7 @@ struct LayerWeightsPtrs { MatPtrT gating_einsum_w1; MatPtrT gating_einsum_w2; MatPtrT linear_w; - // We don't yet have an RMSNorm that accepts all Weight. + // > W8 is likely helpful. MatPtrT pre_attention_norm_scale; MatPtrT pre_ffw_norm_scale; MatPtrT post_attention_norm_scale; @@ -299,17 +300,18 @@ struct LayerWeightsPtrs { // Allocates memory for all the tensors in the layer. Note that this is slow // (non-parallel) and only used for a stand-alone layer. - void AllocateForTest(MatOwners& mat_owners) { + void AllocateForTest(std::vector& mat_owners) { ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { // `backprop/` does not use row accessors and hence requires kPacked. 
- mat_owners.AllocateFor(t.mat, MatPadding::kPacked); + mat_owners.push_back(MatOwner()); + mat_owners.back().AllocateFor(t.mat, MatPadding::kPacked); }); } // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. - void Fixup(MatOwners& mat_owners) { + void Fixup(std::vector& mat_owners) { InitAttWeights(mat_owners); SplitW1(); SplitAttW1(); @@ -317,7 +319,7 @@ struct LayerWeightsPtrs { private: // Copies att_weights from `attn_vec_einsum_w`. - void InitAttWeights(MatOwners& mat_owners) { + void InitAttWeights(std::vector& mat_owners) { // We only use this tensor for Gemma layers. if (layer_config.type != LayerAttentionType::kGemma) return; @@ -343,7 +345,8 @@ struct LayerWeightsPtrs { { static std::mutex m; std::lock_guard lock(m); - mat_owners.AllocateFor(att_weights, MatPadding::kOdd); + mat_owners.push_back(MatOwner()); + mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd); } const size_t T_bytes = att_weights.ElementBytes(); @@ -575,7 +578,8 @@ struct ModelWeightsPtrs { // Instead of reading, only allocates memory for all tensors. Used by // `optimizer.cc` via the `Gemma` constructor without weights. - void AllocateForTest(MatOwners& mat_owners, hwy::ThreadPool& pool) { + void AllocateForTest(std::vector& mat_owners, + hwy::ThreadPool& pool) { // First get a list of all the tensors. std::vector all_mat; all_mat.reserve(10 * c_layers.size()); @@ -583,14 +587,20 @@ struct ModelWeightsPtrs { all_mat.push_back(&t.mat); }); - // `backprop/` does not use row accessors and hence requires kPacked. - mat_owners.AllocateFor(all_mat, MatPadding::kPacked, pool); + const size_t start = mat_owners.size(); + mat_owners.resize(start + all_mat.size()); + + // Allocate in parallel because faulting in large tensors is slow. + pool.Run(0, all_mat.size(), [&](uint64_t task, size_t /*thread*/) { + // `backprop/` does not use row accessors and hence requires kPacked. + mat_owners[start + task].AllocateFor(*all_mat[task], MatPadding::kPacked); + }); } // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Must be called after reading and // updating the attention weights. - void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) { + void Fixup(std::vector& mat_owners, hwy::ThreadPool& pool) { pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { GetLayer(layer)->Fixup(mat_owners); }); @@ -666,7 +676,7 @@ class WeightsOwner { std::unique_ptr> nuq_weights_; // Owns the memory referenced by all `MatPtr`. - MatOwners mat_owners_; + std::vector mat_owners_; }; } // namespace gcpp diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index b171180..1f6aa19 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -79,7 +79,6 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, // M = A rows, K = A cols, N = C cols. 
template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { - const Allocator& allocator = env.ctx.allocator; hwy::ThreadPool& pool = env.ctx.pools.Pool(0); if (env.print_config || env.print_measurement) { fprintf(stderr, "\n"); @@ -92,8 +91,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - MatStorageT c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd); - MatStorageT c_batch("c_batch", C_extents, MatPadding::kOdd); + MatStorageT c_slow_mat("c_slow_batch", C_extents, MatPadding::kOdd); + MatStorageT c_mat("c_batch", C_extents, MatPadding::kOdd); MatStorageT add_storage("add", Extents2D(), MatPadding::kPacked); if (add) { @@ -105,7 +104,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { MatStorageT b_trans = GenerateTransposedMat(B_extents, pool); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C = RowPtrFromMat(c_batch); + const RowPtr C = RowPtrFromMat(c_mat); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; @@ -115,8 +114,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. - BindB(allocator, sizeof(TC), b_trans, env.parallel); - BindC(allocator, A_extents.rows, C, env.parallel); + BindB(b_trans, sizeof(TC), env.parallel); + BindC(c_mat, env.parallel); Tristate use_spinning = Tristate::kDefault; env.ctx.pools.MaybeStartSpinning(use_spinning); diff --git a/ops/matmul.cc b/ops/matmul.cc index e4554a1..2f2f795 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -25,10 +25,12 @@ #include "util/allocator.h" #include "util/basics.h" -#include "util/threading.h" +#include "util/mat.h" +#include "util/threading_context.h" #include "hwy/base.h" #include "hwy/detect_targets.h" #include "hwy/per_target.h" +#include "hwy/profiler.h" #include "hwy/timer.h" namespace gcpp { @@ -386,7 +388,7 @@ std::vector MMCandidates(const Allocator& allocator, size_t M, // rows for that. static size_t NPMultiple(const Allocator& allocator, size_t N, size_t sizeof_TC, size_t nr, size_t num_packages) { - size_t np_multiple = allocator.QuantumBytes() / sizeof_TC; + size_t np_multiple = allocator.BasePageBytes() / sizeof_TC; // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For // `N` < 4096, this can cause significant load imbalance. If split unevenly, // choose a smaller multiple. 
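The new BindB/BindC in the next hunk bind at base-page granularity: each package's slice of B (or column slice of C) is rounded inward to page boundaries before calling BindMemory, and any sub-page remainder at the edges is simply left unbound. The standalone sketch below illustrates only that alignment arithmetic; RowRange, PageAlignedSpan, and kBasePageBytes are illustrative names and constants, not part of this patch.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <utility>

namespace {

constexpr uintptr_t RoundUpTo(uintptr_t x, uintptr_t align) {
  return (x + align - 1) / align * align;
}
constexpr uintptr_t RoundDownTo(uintptr_t x, uintptr_t align) {
  return x / align * align;
}

// Rows [begin, end) assigned to one package.
struct RowRange {
  size_t begin, end;
};

// Returns the page-aligned byte span covering `rows` of a row-major matrix
// whose storage starts at `base` with `stride_bytes` between rows. If the
// range covers less than one full page, the span is empty and nothing would
// be bound, mirroring the `begin != end` check in BindB.
std::pair<uintptr_t, uintptr_t> PageAlignedSpan(uintptr_t base,
                                                size_t stride_bytes,
                                                RowRange rows,
                                                size_t page_bytes) {
  uintptr_t begin = RoundUpTo(base + rows.begin * stride_bytes, page_bytes);
  uintptr_t end = RoundDownTo(base + rows.end * stride_bytes, page_bytes);
  if (end < begin) end = begin;
  return {begin, end};
}

}  // namespace

int main() {
  constexpr size_t kBasePageBytes = 4096;  // assumed; the real value comes from the OS
  const uintptr_t base = uintptr_t{1} << 20;        // page-aligned allocation
  const size_t stride_bytes = 1536 * sizeof(float);  // 6144 bytes per row
  const RowRange per_package[] = {{0, 301}, {301, 602}};
  for (const RowRange rows : per_package) {
    const auto span = PageAlignedSpan(base, stride_bytes, rows, kBasePageBytes);
    // The interior boundary is not page-aligned, so a partial page is skipped.
    printf("rows [%zu, %zu): bind %zu bytes\n", rows.begin, rows.end,
           static_cast<size_t>(span.second - span.first));
  }
  return 0;
}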
@@ -423,4 +425,59 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) have_timer_stop = hwy::platform::HaveTimerStop(cpu100); } +void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { + Allocator& allocator = parallel.allocator(); + if (!allocator.ShouldBind()) return; + if (B.Rows() == 1) return; + + PROFILER_ZONE("Startup.BindB"); + + const IndexRangePartition ranges_np = + parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR); + for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { + const IndexRange& rows_b = ranges_np.Range(pkg_idx); + const size_t node = parallel.Node(pkg_idx); + uintptr_t begin = + reinterpret_cast(B.RowT(rows_b.begin())); + uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); + // B row padding is less than the page size, so only bind the subset that + // is page-aligned. + begin = hwy::RoundUpTo(begin, allocator.BasePageBytes()); + end = hwy::RoundDownTo(end, allocator.BasePageBytes()); + if (HWY_LIKELY(begin != end)) { + allocator.BindMemory(reinterpret_cast(begin), end - begin, node); + } + } +} + +// C is BF16/float, or double for partial +void BindC(const MatPtr& C, MMParallel& parallel) { + Allocator& allocator = parallel.allocator(); + if (!allocator.ShouldBind()) return; + + PROFILER_ZONE("Startup.BindC"); + + const IndexRangePartition ranges_np = parallel.RangesOfNP( + MMParallel::kMaxPackages, C.Cols(), C.ElementBytes(), kNR); + bool ok = true; + for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { + const IndexRange& cols_c = ranges_np.Range(pkg_idx); + // `BindMemory` requires page alignment. These are in bytes. + const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), + allocator.BasePageBytes()); + const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), + allocator.BasePageBytes()); + + const size_t node = parallel.Node(pkg_idx); + for (size_t im = 0; im < C.Rows(); ++im) { + ok &= allocator.BindMemory(C.MutableRowT(im) + begin, + end - begin, node); + } + } + if (HWY_UNLIKELY(!ok)) { + HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(), + ranges_np.NumTasks()); + } +} + } // namespace gcpp diff --git a/ops/matmul.h b/ops/matmul.h index 2ad369e..fc4e8c1 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -46,6 +46,8 @@ namespace gcpp { // `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. constexpr size_t kNR = 4; +// Mostly stateless, can be constructed on the fly by weights.cc, but captures +// the singleton ThreadingContext to reduce MatMul call overhead. class MMParallel { public: static constexpr size_t kMaxPackages = 4; @@ -55,6 +57,8 @@ class MMParallel { HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages); } + Allocator& allocator() const { return ctx_.allocator; } + // Initial static partitioning of B rows across packages. 
IndexRangePartition RangesOfNP(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr) const; @@ -168,31 +172,9 @@ class MMParallel { ThreadingContext& ctx_; }; -template // BF16/float for C, double for partial -void BindC(const Allocator& allocator, size_t M, const RowPtr& C, - MMParallel& parallel) { - if (!allocator.ShouldBind()) return; - - const IndexRangePartition ranges_np = - parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR); - const size_t quantum = allocator.Quantum(); - bool ok = true; - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& cols_c = ranges_np.Range(pkg_idx); - const size_t node = parallel.Node(pkg_idx); - for (size_t im = 0; im < M; ++im) { - // `BindMemory` requires page alignment. - const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum); - const size_t end = hwy::RoundDownTo(cols_c.end(), quantum); - ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC), - node); - } - } - if (HWY_UNLIKELY(!ok)) { - HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", M, C.Cols(), - ranges_np.NumTasks()); - } -} +void BindB(const MatPtr& B, size_t sizeof_TC, MMParallel& parallel); +// C is BF16/float, or double for partial. +void BindC(const MatPtr& C, MMParallel& parallel); // Per-package storage for packed A, and one global C-shaped `partial` for // accumulating partial dot products (sections of K). @@ -227,7 +209,7 @@ class MMStorage { const size_t node = parallel.Node(pkg_idx); size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * pkg_A_[pkg_idx]->ElementBytes(); - bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes()); + bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { HWY_WARN("Failed to bind memory for package %zu", pkg_idx); } @@ -235,7 +217,7 @@ class MMStorage { }); // Avoid cross-package accesses. - BindC(allocator, kMaxM, partial_, parallel); + BindC(partial_storage_, parallel); } // Returns per-package matrix view. @@ -681,29 +663,6 @@ struct MMZone { }; #endif // PROFILER_ENABLED -template -void BindB(const Allocator& allocator, size_t sizeof_TC, const MatPtrT& B, - MMParallel& parallel) { - if (!allocator.ShouldBind()) return; - - const IndexRangePartition ranges_np = - parallel.RangesOfNP(MMParallel::kMaxPackages, B.Rows(), sizeof_TC, kNR); - const size_t quantum = allocator.Quantum(); - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& rows_b = ranges_np.Range(pkg_idx); - const size_t node = parallel.Node(pkg_idx); - uintptr_t begin = reinterpret_cast(B.Row(rows_b.begin())); - uintptr_t end = begin + rows_b.Num() * B.Stride() * sizeof(TB); - // B is not yet guaranteed to have padded rows, so only bind the - // subset that is page-aligned. - begin = hwy::RoundUpTo(begin, quantum); - end = hwy::RoundDownTo(end, quantum); - if (HWY_LIKELY(begin != end)) { - allocator.BindMemory(reinterpret_cast(begin), end - begin, node); - } - } -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_H_ diff --git a/util/allocator.h b/util/allocator.h index a996d6f..bf904c5 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -98,16 +98,12 @@ class Allocator { // = HWY_MAX(LineBytes(), VectorBytes()) size_t StepBytes() const { return step_bytes_; } - // File size multiple required for memory mapping. + // File size multiple required for memory mapping. Also used when binding + // memory to NUMA nodes (see `BindB/BindC`). 
size_t BasePageBytes() const { return base_page_bytes_; } - // Either StepBytes or BasePageBytes if NUMA. + // Desired allocator alignment: Either StepBytes, or BasePageBytes if NUMA. size_t QuantumBytes() const { return quantum_bytes_; } - template - // For rounding down elements to the page size in `BindB/BindC`. - size_t Quantum() const { - return QuantumBytes() / sizeof(T); - } // L1 and L2 are typically per core. size_t L1Bytes() const { return l1_bytes_; } diff --git a/util/mat.cc b/util/mat.cc index d40088f..86baaee 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -19,11 +19,9 @@ #include #include -#include #include "util/threading_context.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/per_target.h" // VectorBytes #include "hwy/profiler.h" @@ -126,15 +124,4 @@ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { mat.SetPtr(storage_.get(), stride); } -void MatOwners::AllocateFor(const std::vector& mats, - MatPadding padding, hwy::ThreadPool& pool) { - const size_t start = owners_.size(); - owners_.resize(start + mats.size()); - - // Allocate in parallel because faulting in large tensors is slow. - pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) { - owners_[start + task].AllocateFor(*mats[task], padding); - }); -} - } // namespace gcpp diff --git a/util/mat.h b/util/mat.h index ce134d1..da662c9 100644 --- a/util/mat.h +++ b/util/mat.h @@ -22,7 +22,6 @@ #include #include -#include // IWYU pragma: begin_exports #include "compression/types.h" // Type @@ -32,7 +31,6 @@ #include "util/basics.h" // Extents2D // IWYU pragma: end_exports #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -376,31 +374,14 @@ class MatOwner { MatOwner& operator=(MatOwner&&) = default; // Allocates the type/extents indicated by `mat` and sets its pointer. + // Ignores `padding` for NUQ tensors, which are always packed. + // Thread-compatible, weights are allocated in parallel. void AllocateFor(MatPtr& mat, MatPadding padding); private: AlignedPtr storage_; }; -// Multiple `MatOwner`, with support for parallel allocation. -class MatOwners { - public: - // Ignores `padding` for NUQ tensors, which are always packed. - void AllocateFor(MatPtr& mat, MatPadding padding) { - if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked; - owners_.push_back(MatOwner()); - owners_.back().AllocateFor(mat, padding); - } - - // Allocates multiple in parallel. Ignores `padding` for NUQ tensors, - // which are always packed. - void AllocateFor(const std::vector& mats, MatPadding padding, - hwy::ThreadPool& pool); - - private: - std::vector owners_; -}; - // `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and // tests to allocate and access tensors of a known type. By contrast, the // heterogeneous model weights are owned by vectors of `MatOwner`.
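With MatOwners disbanded, the call sites in this patch hold a plain vector of MatOwner and either push_back a single owner or pre-size the vector and fill the new slots from a thread pool (AllocateAndBindAll, AllocateForTest). The sketch below shows why the resize-then-fill order matters: concurrent push_back would race, while writing distinct pre-constructed elements is safe. FakeMat/FakeOwner and the use of std::thread are stand-ins for MatPtr/MatOwner and hwy::ThreadPool, not part of gemma.cpp.

#include <cstddef>
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

struct FakeMat {
  size_t rows, cols;
  float* data = nullptr;
};

class FakeOwner {
 public:
  void AllocateFor(FakeMat& mat) {
    storage_ = std::make_unique<float[]>(mat.rows * mat.cols);
    mat.data = storage_.get();
  }

 private:
  std::unique_ptr<float[]> storage_;
};

// Allocates one owner per tensor, in parallel (stand-in for pool.Run).
void AllocateAll(const std::vector<FakeMat*>& mats,
                 std::vector<FakeOwner>& owners) {
  const size_t start = owners.size();
  owners.resize(start + mats.size());  // create slots before spawning threads
  std::vector<std::thread> threads;
  threads.reserve(mats.size());
  for (size_t i = 0; i < mats.size(); ++i) {
    // Each thread writes only its own, already-constructed element.
    threads.emplace_back([&, i] { owners[start + i].AllocateFor(*mats[i]); });
  }
  for (std::thread& t : threads) t.join();
}

int main() {
  FakeMat a{128, 256}, b{64, 512};
  std::vector<FakeMat*> mats = {&a, &b};
  std::vector<FakeOwner> owners;
  AllocateAll(mats, owners);
  printf("allocated %zu tensors\n", owners.size());
  return 0;
}

In the patch itself, AllocateAndBindAll additionally calls BindB on each freshly allocated tensor inside the same parallel loop, so both the page faults and the NUMA binding are spread across the pool.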