diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 2e4dcde..1d5dc5d 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -728,9 +728,10 @@ class MMPerPackage {
  public:
   MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
-               size_t pkg_idx, const IndexRange& range_np)
+               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
       : args_(args),
         pkg_idx_(pkg_idx),
+        cluster_idx_(cluster_idx),
         range_np_(range_np),
         mr_(config.MR()),
         ranges_mc_(config.RangesOfMC(A.rows)),
@@ -821,7 +822,8 @@ class MMPerPackage {
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     MMParallelPolicyT::ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+        pkg_idx_, cluster_idx_,
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -869,7 +871,8 @@ class MMPerPackage {
 
     MMParallelPolicyT::ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+        pkg_idx_, cluster_idx_,
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
@@ -901,7 +904,7 @@ class MMPerPackage {
       // Sequential loop over NC/MC/KC, similar to `loop_nc` below
       // except for the profiler strings and `out_tag`.
       MMParallelPolicyT::ForRangesMC_NC(
-          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
           [&](const IndexRange& range_mc, const IndexRange& range_nc,
               size_t worker) HWY_ATTR {
             MMZone mm_zone;
@@ -951,7 +954,7 @@ class MMPerPackage {
       }
     };  // loop_nc
     MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
@@ -1024,7 +1027,7 @@ class MMPerPackage {
 
         const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
         MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
-                                 pkg_idx_,
+                                 pkg_idx_, cluster_idx_,
                                  [&](const IndexRange& range_K, size_t worker) {
                                    do_range(all_M, range_K, worker);
                                  });
@@ -1032,7 +1035,8 @@ class MMPerPackage {
       }
       case MMParA::kM:
         MMParallelPolicyT::ForRangeMC(
-            args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+            args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
+            [&](size_t row_a, size_t worker) {
               do_range(IndexRange(row_a, row_a + 1), all_K, worker);
             });
         break;
@@ -1106,6 +1110,7 @@ class MMPerPackage {
 
   const MMArgs args_;  // copy for locality
   const size_t pkg_idx_;
+  const size_t cluster_idx_;  // 0 for sequential and nested.
   const IndexRange range_np_;
   // From MMConfig:
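The hunks above thread `cluster_idx` through every call site of the parallel-policy interface: `MMPerPackage` is a functor that receives the policy as a tag argument and forwards its stored `pkg_idx_`/`cluster_idx_` to the policy's static loops. Below is a minimal, self-contained sketch of that tag-dispatch pattern, with stand-in types and an assumed 16-workers-per-package / 8-per-cluster layout (not the real gemma.cpp types):

#include <cstddef>
#include <cstdio>

struct SeqPolicy {  // stand-in for MMSequentialPolicy
  template <class Func>
  static void ForNP(size_t pkg_idx, size_t cluster_idx, const Func& func) {
    // Assumed layout: 16 workers per package, 8 per cluster.
    func(/*worker=*/pkg_idx * 16 + cluster_idx * 8);
  }
};

struct PerPackage {  // stand-in for MMPerPackage
  size_t pkg_idx_;
  size_t cluster_idx_;

  template <class PolicyT>
  void operator()(PolicyT) const {
    PolicyT::ForNP(pkg_idx_, cluster_idx_,
                   [](size_t worker) { std::printf("worker %zu\n", worker); });
  }
};

int main() {
  const PerPackage per_package{/*pkg_idx_=*/0, /*cluster_idx_=*/1};
  per_package(SeqPolicy());  // prints "worker 8" under the assumed layout
  return 0;
}

The real policies also receive `ThreadingContext` and `IndexRange` arguments; only the dispatch shape is shown here.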
@@ -1135,23 +1140,27 @@ struct MMImpl {
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
-                                    const MMConfig& config,
-                                    ParallelismType parallelism_type) {
+                                    const MMConfig& config, MMOptions options) {
     PROFILER_ZONE("MM.DoMatMul");
     const size_t pkg_idx = 0;
     HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
     const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-    switch (parallelism_type) {
+    switch (options.parallelism_type) {
       case ParallelismType::kNested:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMNestedParallelPolicy(), A, B, C_rows);
+        HWY_DASSERT(options.cluster_idx == 0);
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMNestedParallelPolicy(), A, B, C_rows);
         break;
       case ParallelismType::kSequential:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMSequentialPolicy(), A, B, C_rows);
-      case ParallelismType::kNone:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMSequentialPolicy(), A, B, C_rows);
+        break;
       case ParallelismType::kCluster:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
+        break;
+      default:
         HWY_ABORT("Parallelism type not implemented.");
         break;
     }
@@ -1210,8 +1218,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<float>(A.Scale()) * B.Scale(),
                     add);
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
-                     options.parallelism_type);
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
     return &per_key;
   }
@@ -1240,7 +1247,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     const MMConfig& cfg = tuner.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
+    MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed =
         static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
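With that, `DoMatMul` selects a policy from `MMOptions` instead of a bare `ParallelismType`, and each case must end in `break` so exactly one policy runs. A self-contained sketch of the same dispatch, where the enum values and the two `MMOptions` fields follow the diff but the work is reduced to a printf stand-in:

#include <cstddef>
#include <cstdio>

enum class ParallelismType { kNested, kSequential, kCluster };

struct MMOptions {  // fields as in the diff; defaults are assumed
  ParallelismType parallelism_type = ParallelismType::kNested;
  size_t cluster_idx = 0;  // which cluster to use; 0 for kNested
};

void Dispatch(const MMOptions& options) {
  switch (options.parallelism_type) {
    case ParallelismType::kNested:
      std::printf("all clusters of the package\n");
      break;
    case ParallelismType::kSequential:
      std::printf("calling thread, worker base of cluster %zu\n",
                  options.cluster_idx);
      break;  // without this break, the kCluster case would also run
    case ParallelismType::kCluster:
      std::printf("thread pool of cluster %zu only\n", options.cluster_idx);
      break;
  }
}

int main() {
  MMOptions options;
  options.parallelism_type = ParallelismType::kCluster;
  options.cluster_idx = 1;
  Dispatch(options);  // prints "thread pool of cluster 1 only"
  return 0;
}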
diff --git a/ops/matmul.h b/ops/matmul.h
index 620d382..5c526de 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -81,9 +81,10 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t cluster_idx, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     func(range_np, base_idx);
   }
 
@@ -91,8 +92,10 @@ struct MMSequentialPolicy {
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
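Both `base_idx` computations above place each package at a stride of `MaxWorkersPerPackage()` and each cluster within it at a stride of `MaxWorkersPerCluster()`, so per-cluster worker indices never alias. Toy numbers illustrating the resulting blocks (2 packages of 2 clusters with 8 workers each, assumed rather than queried from a real `ThreadingContext`):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t kMaxWorkersPerPackage = 16;  // assumed: 2 clusters * 8 workers
  const size_t kMaxWorkersPerCluster = 8;   // assumed
  for (size_t pkg_idx = 0; pkg_idx < 2; ++pkg_idx) {
    for (size_t cluster_idx = 0; cluster_idx < 2; ++cluster_idx) {
      // Same arithmetic as the base_idx lines in the hunks above.
      const size_t base_idx = pkg_idx * kMaxWorkersPerPackage +
                              cluster_idx * kMaxWorkersPerCluster;
      std::printf("pkg %zu cluster %zu -> worker indices [%zu, %zu)\n",
                  pkg_idx, cluster_idx, base_idx,
                  base_idx + kMaxWorkersPerCluster);
    }
  }
  return 0;
}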
@@ -105,14 +108,68 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
       func(row_a, base_idx);
     }
   }
 };
 
+struct MMClusterParallelPolicy {
+  template <class Func>
+  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
+                     const Func& func) {
+    func(/*pkg_idx=*/0);
+  }
+
+  template <class Func>
+  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
+                    size_t cluster_idx, const Func& func) {
+    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const IndexRangePartition worker_ranges = StaticPartition(
+        range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
+    ParallelizeOneRange(worker_ranges, cluster,
+                        [&](const IndexRange& worker_range, size_t worker) {
+                          func(worker_range, worker);
+                        });
+  }
+
+  template <class Func>
+  static void ForRangesMC_NC(ThreadingContext& ctx,
+                             const IndexRangePartition& ranges_mc,
+                             const IndexRangePartition& ranges_nc,
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+
+    // Low-batch: avoid Divide/Remainder.
+    if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
+      ParallelizeOneRange(ranges_nc, cluster,
+                          [&](const IndexRange& range_nc, size_t thread) {
+                            func(ranges_mc.Range(0), range_nc, thread);
+                          });
+    } else {
+      ParallelizeTwoRanges(
+          ranges_mc, ranges_nc, cluster,
+          [&](const IndexRange& range_mc, const IndexRange& range_nc,
+              size_t thread) { func(range_mc, range_nc, thread); });
+    }
+  }
+
+  template <class Func>
+  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    cluster.Run(range_mc.begin(), range_mc.end(),
+                [&](uint64_t row_a, size_t thread) { func(row_a, thread); });
+  }
+};
+
 struct MMNestedParallelPolicy {
   template <class Func>
   static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
@@ -132,10 +189,11 @@ struct MMNestedParallelPolicy {
 
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t /*cluster_idx*/, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
@@ -175,11 +233,13 @@ struct MMNestedParallelPolicy {
 
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
+                             size_t pkg_idx, size_t /*cluster_idx*/,
+                             const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
@@ -221,9 +281,11 @@ struct MMNestedParallelPolicy {
   }
 
   // Calls `func(row_a, worker)` in parallel.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
+                         size_t pkg_idx, size_t /*cluster_idx*/,
+                         const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     ctx.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),
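As a rough model of `MMClusterParallelPolicy::ForNP` above: the range is split into `cluster.NumWorkers() * inner_tasks` tasks whose sizes are rounded up to `nx_multiple`, and each task runs on one worker of the chosen cluster's pool. The rounding below is an assumed simplification of `StaticPartition`/`ParallelizeOneRange`, not their real implementations:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t begin = 0, end = 100;  // range_np bounds, values assumed
  const size_t num_workers = 4, inner_tasks = 2, nx_multiple = 8;  // assumed
  const size_t max_tasks = num_workers * inner_tasks;
  // Round the per-task size up to nx_multiple (assumed StaticPartition-like).
  size_t task_size = (end - begin + max_tasks - 1) / max_tasks;
  task_size = (task_size + nx_multiple - 1) / nx_multiple * nx_multiple;
  for (size_t task = 0, b = begin; b < end; ++task, b += task_size) {
    const size_t e = std::min(b + task_size, end);
    // In the real code, ParallelizeOneRange would hand this to a pool worker.
    std::printf("task %zu -> rows [%zu, %zu)\n", task, b, e);
  }
  return 0;
}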