diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index b54ce05..8f91114 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -1135,28 +1135,23 @@ struct MMImpl { template static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, RowPtrs C_rows, const MMArgs& args, - const MMConfig& config) { + const MMConfig& config, + ParallelismType parallelism_type) { PROFILER_ZONE("MM.DoMatMul"); - static const auto zone = - args.env->ctx.profiler.AddZone("MM.DoMatMul.PerPkg"); + const size_t pkg_idx = 0; + HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); + const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - if constexpr (kMaxPackages > 1) { - // Outermost loop: static NUMA-aware partition of B rows across packages. - MMNestedParallelPolicy::ForPkg( - args.env->ctx, args.per_key->ranges_np.NumTasks(), - [&](size_t pkg_idx) { - MMZone mm_zone; - mm_zone.MaybeEnter(pkg_idx, zone, args); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMNestedParallelPolicy(), A, B, C_rows); - }); - } else { - const size_t pkg_idx = 0; - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( - MMNestedParallelPolicy(), A, B, C_rows); + switch (parallelism_type) { + case ParallelismType::kNested: + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)( + MMNestedParallelPolicy(), A, B, C_rows); + break; + case ParallelismType::kNone: + case ParallelismType::kSequential: + case ParallelismType::kCluster: + HWY_ABORT("Parallelism type not implemented."); + break; } } }; @@ -1210,10 +1205,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; + // Default to nested parallelism. + const ParallelismType parallelism_type = ParallelismType::kNested; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add); if (HWY_LIKELY(tuner.Best())) { - MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best()); + MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type); return &per_key; } @@ -1242,7 +1239,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMImpl::DoMatMul(A, B, C_rows, args, cfg); + MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type); const uint64_t t1 = env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); const double min_elapsed = static_cast(tuner.NotifyTicks(t1 - t0)) /