From 973e284ed6ee0595de90a399a52a869542563fc7 Mon Sep 17 00:00:00 2001
From: Marie White
Date: Fri, 29 Aug 2025 05:40:06 -0700
Subject: [PATCH] Refactor Matmul to use a policy class for parallelization.

PiperOrigin-RevId: 800864489
---
 gemma/weights.cc |   4 +-
 ops/matmul-inl.h | 108 ++++++++++++++++++++++++++---------------------
 ops/matmul.cc    |  27 ++++++------
 ops/matmul.h     | 107 +++++++++++++++++++++++-----------------------
 4 files changed, 125 insertions(+), 121 deletions(-)

diff --git a/gemma/weights.cc b/gemma/weights.cc
index 4124247..3425a60 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector& tensors,
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
 
-  MMParallel parallel(ctx);
-
   // Allocate in parallel because faulting in large tensors is slow.
   ctx.pools.Pool().Run(
       0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
@@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector& tensors,
 
         owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
                                          tensor.padding);
-        BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
+        BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
       });
 }
 
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 56cb06f..29be665 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -779,17 +779,19 @@ class MMPerPackage {
   // B and maybe A are decompressed several call layers lower, but not all
   // member functions depend on TA/TB, so pass them as an argument instead of
   // templating the class.
-  template
-  HWY_NOINLINE void operator()(const MatPtrT& A, const MatPtrT& B,
+  template
+  HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
+                               const MatPtrT& A, const MatPtrT& B,
                                RowPtrs C_rows) const {
     if constexpr (WantDecompressA()) {
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
-      DecompressA(A, A_view);
+      DecompressA(A, A_view);
       constexpr bool A_padded = true;  // MMStorage `pkg_A_` is padded.
-      DispatchOrder(A_view, A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
     } else {
       const bool A_padded = HasPadding(A);
-      DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
+                    C_rows);
     }
   }
 
@@ -828,28 +830,30 @@ class MMPerPackage {
   }
 
   // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
-  template
-  HWY_INLINE void DispatchOrder(const StridedView A, const bool A_padded,
+  template
+  HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
+                                const StridedView A, const bool A_padded,
                                 const MatPtrT& B, RowPtrs C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(A, A_padded, B, C_rows);
+        return DoNT(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(A, A_padded, B, C_rows);
+        return DoNT_K(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(A, A_padded, B, C_rows);
+        return DoNT_MT(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(A, A_padded, B, C_rows);
+        return DoNT_MT_K(parallel_policy, A, A_padded, B, C_rows);
       default:
        HWY_UNREACHABLE;
    }
  }

  // Single M and K ranges, parallel N. Fills all of C directly.
-  template
-  HWY_INLINE void DoNT(const StridedView A, const bool A_padded,
-                       const MatPtrT& B, RowPtrs C_rows) const {
+  template
+  HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A,
+                       const bool A_padded, const MatPtrT& B,
+                       RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -861,9 +865,9 @@ class MMPerPackage {
         Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
-    args_.env->parallel.ForNP(
-        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+    MMParallelPolicyT::ForNP(
+        args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
+        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args_);
 
@@ -881,9 +885,10 @@ class MMPerPackage {
   }
 
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
-  template
-  HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded,
-                         const MatPtrT& B, RowPtrs C_rows) const {
+  template
+  HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A,
+                         const bool A_padded, const MatPtrT& B,
+                         RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,9 +914,9 @@ class MMPerPackage {
       }
     };
 
-    args_.env->parallel.ForNP(
-        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+    MMParallelPolicyT::ForNP(
+        args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
+        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
          MMZone mm_zone;
          mm_zone.MaybeEnter(worker, zone, args_);
 
@@ -930,9 +935,10 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template
-  HWY_INLINE void DoNT_MT(const StridedView A, const bool A_padded,
-                          const MatPtrT& B, RowPtrs C_rows) const {
+  template
+  HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A,
+                          const bool A_padded, const MatPtrT& B,
+                          RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
@@ -942,8 +948,8 @@ class MMPerPackage {
 
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
-    args_.env->parallel.ForRangesMC_NC(
-        ranges_mc_, ranges_nc_, pkg_idx_,
+    MMParallelPolicyT::ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
            size_t worker) HWY_ATTR {
          MMZone mm_zone;
@@ -965,9 +971,10 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
-  template
-  HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded,
-                            const MatPtrT& B, RowPtrs C_rows) const {
+  template
+  HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A,
+                            const bool A_padded, const MatPtrT& B,
+                            RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -992,8 +999,8 @@ class MMPerPackage {
              out_tag, args_, C_rows);
        }
      };  // loop_nc
-    args_.env->parallel.ForRangesMC_NC(
-        ranges_mc_, ranges_nc_, pkg_idx_,
+    MMParallelPolicyT::ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
            size_t worker) HWY_ATTR {
          MMZone mm_zone;
@@ -1014,6 +1021,7 @@ class MMPerPackage {
   }
 
   // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  template
   HWY_NOINLINE void DoDecompressA(const MatPtrT& A,
                                   const StridedViewBF A_view,
                                   MMParA par_a) const {
@@ -1064,16 +1072,16 @@ class MMPerPackage {
       // line to avoid false sharing.
       const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
 
-      args_.env->parallel.ForNP(
-          all_K, multiple_K, inner_tasks, pkg_idx_,
-          [&](const IndexRange& range_K, size_t worker) {
-            do_range(all_M, range_K, worker);
-          });
+      MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
+                               pkg_idx_,
+                               [&](const IndexRange& range_K, size_t worker) {
+                                 do_range(all_M, range_K, worker);
+                               });
      break;
    }
    case MMParA::kM:
-      args_.env->parallel.ForRangeMC(
-          all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+      MMParallelPolicyT::ForRangeMC(
+          args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
            do_range(IndexRange(row_a, row_a + 1), all_K, worker);
          });
      break;
@@ -1081,12 +1089,13 @@ class MMPerPackage {
   }
 
   // Autotuning wrapper for `DoDecompressA`.
+  template
   HWY_INLINE void DecompressA(const MatPtrT& A,
                               const StridedViewBF A_view) const {
     MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_];
     if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA(A, A_view, *autotune.Best());
+      return DoDecompressA(A, A_view, *autotune.Best());
     }
 
     // First call: generate candidates.
@@ -1099,7 +1108,7 @@ class MMPerPackage {
     const MMParA& par_a = autotune.NextConfig();
 
     const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA(A, A_view, par_a);
+    DoDecompressA(A, A_view, par_a);
     const uint64_t t1 =
         args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
@@ -1185,19 +1194,21 @@ struct MMImpl {
 
     if constexpr (kMaxPackages > 1) {
       // Outermost loop: static NUMA-aware partition of B rows across packages.
-      args.env->parallel.ForPkg(
-          args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
+      MMNestedParallelPolicy::ForPkg(
+          args.env->ctx, args.per_key->ranges_np.NumTasks(),
+          [&](size_t pkg_idx) {
            MMZone mm_zone;
            mm_zone.MaybeEnter(pkg_idx, zone, args);
            const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-            MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
-                                                                       C_rows);
+            MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
+                MMNestedParallelPolicy(), A, B, C_rows);
          });
    } else {
      const size_t pkg_idx = 0;
      HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
      const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-      MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
+      MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
+          MMNestedParallelPolicy(), A, B, C_rows);
    }
  }
};
@@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
     // invalidates `MMAutoTune::Best()`
     index = env.per_key.size();
-    env.per_key.push_back(
-        MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
+    env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
   }
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune& tuner = per_key.autotune;
diff --git a/ops/matmul.cc b/ops/matmul.cc
index 71f2efe..711eac1 100644
--- a/ops/matmul.cc
+++ b/ops/matmul.cc
@@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
   return np_multiple;
 }
 
-IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
-                                           size_t sizeof_TC, size_t nr) const {
-  const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
+IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
+                                 size_t N, size_t sizeof_TC, size_t nr) {
+  const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
   return StaticPartition(
       IndexRange(0, N), num_packages,
-      NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
+      NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
 }
 
-MatMulEnv::MatMulEnv(ThreadingContext& ctx)
-    : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
+MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
   row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM));  // C
 }
 
-void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
-  Allocator& allocator = parallel.allocator();
+void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
+  Allocator& allocator = ctx.allocator;
   if (!allocator.ShouldBind()) return;
   if (B.Rows() == 1) return;
 
   PROFILER_ZONE("Startup.BindB");
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
+      MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
-    const size_t node = parallel.Node(pkg_idx);
+    const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
     uintptr_t begin = reinterpret_cast(B.RowBytes(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
     // B row padding is less than the page size, so only bind the subset that
@@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
 }
 
 // C is BF16/float
-void BindC(MatPtr& C, MMParallel& parallel) {
-  Allocator& allocator = parallel.allocator();
+void BindC(ThreadingContext& ctx, MatPtr& C) {
+  Allocator& allocator = ctx.allocator;
   if (!allocator.ShouldBind()) return;
 
   PROFILER_ZONE("Startup.BindC");
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
+      MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
   bool ok = true;
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) {
     const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
                                         allocator.BasePageBytes());
-    const size_t node = parallel.Node(pkg_idx);
+    const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
     for (size_t im = 0; im < C.Rows(); ++im) {
       ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
     }
diff --git a/ops/matmul.h b/ops/matmul.h
index e4c436f..16028f3 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -53,35 +53,31 @@ constexpr size_t kNR = 4;
 // or less on ISAs with fewer registers, or for the last few rows of A.
 static constexpr size_t kMaxMR = 4;
 
-// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
-// the ThreadingContext to shorten call sites.
-class MMParallel {
- public:
-  // `ctx` must outlive this object.
-  MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
-    if (ctx_.pools.NumPackages() > kMaxPackages) {
-      HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
-               ctx_.pools.NumPackages(), kMaxPackages);
-    }
-  }
+IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
+                                 size_t N, size_t sizeof_TC, size_t nr);
 
-  Allocator& allocator() const { return ctx_.allocator; }
+enum class ParallelismType : uint8_t {
+  kNone,
+  // No parallelism.
+  kSequential,
+  // Parallelism at cluster level.
+  kCluster,
+  // Parallelism at package level.
+  kNested,
+};
 
-  // Initial static partitioning of B rows across packages.
-  IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
-                                 size_t sizeof_TC, size_t nr) const;
+struct MMOptions {
+  ParallelismType parallelism_type_ = ParallelismType::kNested;
+  uint8_t cluster_idx_ = 0;
+};
 
-  // For `BindB` and `BindC`.
-  size_t Node(size_t pkg_idx) const {
-    return ctx_.topology.GetCluster(pkg_idx, 0).Node();
-  }
-
-  // Calls `func(pkg_idx)` for each package in parallel.
+struct MMNestedParallelPolicy {
   template
-  void ForPkg(const size_t max_packages, const Func& func) {
+  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
+                     const Func& func) {
     if constexpr (kMaxPackages > 1) {
-      ctx_.pools.AllPackages().Run(
-          0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
+      ctx.pools.AllPackages().Run(
+          0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
          [&](uint64_t task, size_t pkg_idx) {
            HWY_DASSERT(task == pkg_idx);
            (void)task;
@@ -95,16 +91,17 @@ class MMParallel {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
   template
-  void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
-             size_t pkg_idx, const Func& func) {
+  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
+                    const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
 
     // Single cluster: parallel-for over static partition of `range_np`.
-    hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
     if (num_clusters == 1) {
-      hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
       const IndexRangePartition worker_ranges = StaticPartition(
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
@@ -120,9 +117,9 @@ class MMParallel {
     ParallelizeOneRange(
         nx_ranges, all_clusters,
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
-          hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
+              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
          // Parallel-for over sub-ranges of `cluster_range` within the cluster.
          const IndexRangePartition worker_ranges = StaticPartition(
              nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
@@ -137,18 +134,19 @@ class MMParallel {
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
   template
-  void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
-                      const IndexRangePartition& ranges_nc, size_t pkg_idx,
-                      const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
-    hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
+  static void ForRangesMC_NC(ThreadingContext& ctx,
+                             const IndexRangePartition& ranges_mc,
+                             const IndexRangePartition& ranges_nc,
+                             size_t pkg_idx, const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
     // Single (big) cluster: collapse two range indices into one parallel-for
     // to reduce the number of fork-joins.
     if (num_clusters == 1) {
       const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
@@ -171,8 +169,8 @@ class MMParallel {
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
-          hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
          ParallelizeOneRange(ranges_mc, cluster,
                              [&](const IndexRange& range_mc, size_t thread) {
                                func(range_mc, range_nc, cluster_base + thread);
@@ -182,21 +180,18 @@ class MMParallel {
 
   // Calls `func(row_a, worker)` in parallel.
   template
-  void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
-                  const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
-    ctx_.pools.Pool(pkg_idx).Run(
+  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                         size_t pkg_idx, const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    ctx.pools.Pool(pkg_idx).Run(
        range_mc.begin(), range_mc.end(),
        [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
   }
-
- private:
-  ThreadingContext& ctx_;
 };
 
-void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
+void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 // C is BF16/float.
-void BindC(MatPtr& C, MMParallel& parallel);
+void BindC(ThreadingContext& ctx, MatPtr& C);
 
 // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
 #pragma pack(push, 1)  // power of two size
@@ -250,15 +245,18 @@ class MMStorage {
 
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
-  MMStorage(const Allocator& allocator, MMParallel& parallel) {
+  MMStorage(ThreadingContext& ctx) {
     // Per-package allocation so each can decompress A into its own copy.
     // Must be padded, see `DoDecompressA`.
-    parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
+    // Default to nested parallel policy.
+    MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
+      Allocator& allocator = ctx.allocator;
+
       pkg_A_[pkg_idx].reset(new MatStorageT(
          "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
       if (allocator.ShouldBind()) {
-        const size_t node = parallel.Node(pkg_idx);
+        const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
         size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
                        pkg_A_[pkg_idx]->ElementBytes();
         bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
@@ -607,9 +605,9 @@ class MMKeys {
 
 // Per-MatMul-shape state.
 struct MMPerKey {
-  MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
-           MMParallel& parallel)
-      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
+  MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
+           size_t sizeof_TC, size_t nr)
+      : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
     HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
   }
 
@@ -639,7 +637,6 @@ struct MatMulEnv {
   // Whether to print the best config immediately after autotuning finished.
   bool print_best = false;
 
-  MMParallel parallel;
  MMStorage storage;
  MMKeys keys;
  std::vector per_key;
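Note on the extension point this refactor opens up (illustration only, not part of the diff above): MMPerPackage and MMStorage now route all work distribution through the static ForPkg / ForNP / ForRangesMC_NC / ForRangeMC hooks of a policy type plus an explicit ThreadingContext, so a different strategy can be swapped in without touching the matmul loops. Below is a minimal single-threaded sketch with the same static interface; the name MMSequentialPolicySketch and its serial loop bodies are assumptions for illustration, not code from this change.

// Hypothetical single-threaded policy (illustration only, not in this patch).
// It mirrors the static interface of MMNestedParallelPolicy, so it could be
// passed to MMPerPackage::operator() in its place: every "parallel" hook
// simply runs its tasks in order and reports worker index 0.
struct MMSequentialPolicySketch {
  template <class Func>
  static void ForPkg(ThreadingContext& ctx, size_t max_packages,
                     const Func& func) {
    // Visit each package in turn instead of using ctx.pools.AllPackages().
    const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
    for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) func(pkg_idx);
  }

  template <class Func>
  static void ForNP(ThreadingContext& /*ctx*/, const IndexRange& range_np,
                    size_t /*nx_multiple*/, size_t /*inner_tasks*/,
                    size_t /*pkg_idx*/, const Func& func) {
    func(range_np, /*worker=*/0);  // One task covering the whole range.
  }

  template <class Func>
  static void ForRangesMC_NC(ThreadingContext& /*ctx*/,
                             const IndexRangePartition& ranges_mc,
                             const IndexRangePartition& ranges_nc,
                             size_t /*pkg_idx*/, const Func& func) {
    // Visit every mc x nc block sequentially.
    for (size_t im = 0; im < ranges_mc.NumTasks(); ++im) {
      for (size_t in = 0; in < ranges_nc.NumTasks(); ++in) {
        func(ranges_mc.Range(im), ranges_nc.Range(in), /*worker=*/0);
      }
    }
  }

  template <class Func>
  static void ForRangeMC(ThreadingContext& /*ctx*/, const IndexRange& range_mc,
                         size_t /*pkg_idx*/, const Func& func) {
    for (size_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
      func(row_a, /*worker=*/0);
    }
  }
};

Whether worker slot 0 is always a valid per-worker scratch index depends on how callers size their per-worker storage, so treat this as a sketch of the interface rather than a drop-in policy; the ParallelismType::kSequential enumerator added in matmul.h suggests such a policy is intended as a follow-up.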