Refactor Matmul to use a policy class for parallelization.

PiperOrigin-RevId: 800864489
This commit is contained in:
Marie White 2025-08-29 05:40:06 -07:00 committed by Copybara-Service
parent 6c39a2dea4
commit 973e284ed6
4 changed files with 125 additions and 121 deletions

View File

@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
const size_t start = owners.size(); const size_t start = owners.size();
owners.resize(start + tensors.size()); owners.resize(start + tensors.size());
MMParallel parallel(ctx);
// Allocate in parallel because faulting in large tensors is slow. // Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run( ctx.pools.Pool().Run(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) { 0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator, owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
tensor.padding); tensor.padding);
BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel); BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
}); });
} }

View File

@ -779,17 +779,19 @@ class MMPerPackage {
// B and maybe A are decompressed several call layers lower, but not all // B and maybe A are decompressed several call layers lower, but not all
// member functions depend on TA/TB, so pass them as an argument instead of // member functions depend on TA/TB, so pass them as an argument instead of
// templating the class. // templating the class.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B, HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
const MatPtrT<TA>& A, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const { RowPtrs<TC> C_rows) const {
if constexpr (WantDecompressA<TA>()) { if constexpr (WantDecompressA<TA>()) {
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
DecompressA(A, A_view); DecompressA<MMParallelPolicyT>(A, A_view);
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded. constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
DispatchOrder(A_view, A_padded, B, C_rows); DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
} else { } else {
const bool A_padded = HasPadding(A); const bool A_padded = HasPadding(A);
DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows); DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
C_rows);
} }
} }
@ -828,28 +830,30 @@ class MMPerPackage {
} }
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded, HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const { RowPtrs<TC> C_rows) const {
switch (order_) { switch (order_) {
case MMOrder::kNT: case MMOrder::kNT:
return DoNT(A, A_padded, B, C_rows); return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_K: case MMOrder::kNT_K:
return DoNT_K(A, A_padded, B, C_rows); return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_MT: case MMOrder::kNT_MT:
return DoNT_MT(A, A_padded, B, C_rows); return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_MT_K: case MMOrder::kNT_MT_K:
return DoNT_MT_K(A, A_padded, B, C_rows); return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
default: default:
HWY_UNREACHABLE; HWY_UNREACHABLE;
} }
} }
// Single M and K ranges, parallel N. Fills all of C directly. // Single M and K ranges, parallel N. Fills all of C directly.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded, HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_mc_.NumTasks() == 1);
HWY_DASSERT(ranges_kc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@ -861,9 +865,9 @@ class MMPerPackage {
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
// Similar to `loop_nc` below, but here we hoisted `A_view`. // Similar to `loop_nc` below, but here we hoisted `A_view`.
args_.env->parallel.ForNP( MMParallelPolicyT::ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR { pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone; MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_); mm_zone.MaybeEnter(worker, zone, args_);
@ -881,9 +885,10 @@ class MMPerPackage {
} }
// Single M range, parallel N, sequential K. Sets C, then accumulates. // Single M range, parallel N, sequential K. Sets C, then accumulates.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded, HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_mc_.NumTasks() == 1);
const IndexRange& range_mc = ranges_mc_.Range(0); const IndexRange& range_mc = ranges_mc_.Range(0);
@ -909,9 +914,9 @@ class MMPerPackage {
} }
}; };
args_.env->parallel.ForNP( MMParallelPolicyT::ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR { pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone; MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_); mm_zone.MaybeEnter(worker, zone, args_);
@ -930,9 +935,10 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K. // Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel. // Fills `mc x nc` sections of C directly, in parallel.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded, HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
HWY_DASSERT(ranges_kc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0);
@ -942,8 +948,8 @@ class MMPerPackage {
// Sequential loop over NC/MC/KC, similar to `loop_nc` below // Sequential loop over NC/MC/KC, similar to `loop_nc` below
// except for the profiler strings and `out_tag`. // except for the profiler strings and `out_tag`.
args_.env->parallel.ForRangesMC_NC( MMParallelPolicyT::ForRangesMC_NC(
ranges_mc_, ranges_nc_, pkg_idx_, args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc, [&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR { size_t worker) HWY_ATTR {
MMZone mm_zone; MMZone mm_zone;
@ -965,9 +971,10 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K. // Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel. // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded, HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
const size_t kc_max = ranges_kc_.TaskSize(); const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC); HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@ -992,8 +999,8 @@ class MMPerPackage {
out_tag, args_, C_rows); out_tag, args_, C_rows);
} }
}; // loop_nc }; // loop_nc
args_.env->parallel.ForRangesMC_NC( MMParallelPolicyT::ForRangesMC_NC(
ranges_mc_, ranges_nc_, pkg_idx_, args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc, [&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR { size_t worker) HWY_ATTR {
MMZone mm_zone; MMZone mm_zone;
@ -1014,6 +1021,7 @@ class MMPerPackage {
} }
// Decompresses all `M x K` from `A` into padded BF16 `A_view`. // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
template <typename MMParallelPolicyT>
HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A, HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view, const StridedViewBF A_view,
MMParA par_a) const { MMParA par_a) const {
@ -1064,16 +1072,16 @@ class MMPerPackage {
// line to avoid false sharing. // line to avoid false sharing.
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
args_.env->parallel.ForNP( MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
all_K, multiple_K, inner_tasks, pkg_idx_, pkg_idx_,
[&](const IndexRange& range_K, size_t worker) { [&](const IndexRange& range_K, size_t worker) {
do_range(all_M, range_K, worker); do_range(all_M, range_K, worker);
}); });
break; break;
} }
case MMParA::kM: case MMParA::kM:
args_.env->parallel.ForRangeMC( MMParallelPolicyT::ForRangeMC(
all_M, pkg_idx_, [&](size_t row_a, size_t worker) { args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
do_range(IndexRange(row_a, row_a + 1), all_K, worker); do_range(IndexRange(row_a, row_a + 1), all_K, worker);
}); });
break; break;
@ -1081,12 +1089,13 @@ class MMPerPackage {
} }
// Autotuning wrapper for `DoDecompressA`. // Autotuning wrapper for `DoDecompressA`.
template <typename MMParallelPolicyT>
HWY_INLINE void DecompressA(const MatPtrT<float>& A, HWY_INLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view) const { const StridedViewBF A_view) const {
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_]; MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
if (HWY_LIKELY(autotune.Best())) { if (HWY_LIKELY(autotune.Best())) {
return DoDecompressA(A, A_view, *autotune.Best()); return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
} }
// First call: generate candidates. // First call: generate candidates.
@ -1099,7 +1108,7 @@ class MMPerPackage {
const MMParA& par_a = autotune.NextConfig(); const MMParA& par_a = autotune.NextConfig();
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
DoDecompressA(A, A_view, par_a); DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
const uint64_t t1 = const uint64_t t1 =
args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
@ -1185,19 +1194,21 @@ struct MMImpl {
if constexpr (kMaxPackages > 1) { if constexpr (kMaxPackages > 1) {
// Outermost loop: static NUMA-aware partition of B rows across packages. // Outermost loop: static NUMA-aware partition of B rows across packages.
args.env->parallel.ForPkg( MMNestedParallelPolicy::ForPkg(
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) { args.env->ctx, args.per_key->ranges_np.NumTasks(),
[&](size_t pkg_idx) {
MMZone mm_zone; MMZone mm_zone;
mm_zone.MaybeEnter(pkg_idx, zone, args); mm_zone.MaybeEnter(pkg_idx, zone, args);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
C_rows); MMNestedParallelPolicy(), A, B, C_rows);
}); });
} else { } else {
const size_t pkg_idx = 0; const size_t pkg_idx = 0;
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows); MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
} }
} }
}; };
@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// invalidates `MMAutoTune::Best()` // invalidates `MMAutoTune::Best()`
index = env.per_key.size(); index = env.per_key.size();
env.per_key.push_back( env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
} }
MMPerKey& per_key = env.per_key[index]; MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune; MMAutoTune<MMConfig>& tuner = per_key.autotune;

View File

@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
return np_multiple; return np_multiple;
} }
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
size_t sizeof_TC, size_t nr) const { size_t N, size_t sizeof_TC, size_t nr) {
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages()); const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
return StaticPartition( return StaticPartition(
IndexRange(0, N), num_packages, IndexRange(0, N), num_packages,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages)); NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
} }
MatMulEnv::MatMulEnv(ThreadingContext& ctx) MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100]; char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100); have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
} }
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) { void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
Allocator& allocator = parallel.allocator(); Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;
if (B.Rows() == 1) return; if (B.Rows() == 1) return;
PROFILER_ZONE("Startup.BindB"); PROFILER_ZONE("Startup.BindB");
const IndexRangePartition ranges_np = const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR); MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx); const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx); const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin())); uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that // B row padding is less than the page size, so only bind the subset that
@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
} }
// C is BF16/float // C is BF16/float
void BindC(MatPtr& C, MMParallel& parallel) { void BindC(ThreadingContext& ctx, MatPtr& C) {
Allocator& allocator = parallel.allocator(); Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;
PROFILER_ZONE("Startup.BindC"); PROFILER_ZONE("Startup.BindC");
const IndexRangePartition ranges_np = const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR); MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
bool ok = true; bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx); const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) {
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
allocator.BasePageBytes()); allocator.BasePageBytes());
const size_t node = parallel.Node(pkg_idx); const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
for (size_t im = 0; im < C.Rows(); ++im) { for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
} }

View File

@ -53,35 +53,31 @@ constexpr size_t kNR = 4;
// or less on ISAs with fewer registers, or for the last few rows of A. // or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4; static constexpr size_t kMaxMR = 4;
// Mostly stateless, can be constructed on the fly by weights.cc. Captures the IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
// the ThreadingContext to shorten call sites. size_t N, size_t sizeof_TC, size_t nr);
class MMParallel {
public:
// `ctx` must outlive this object.
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
if (ctx_.pools.NumPackages() > kMaxPackages) {
HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
ctx_.pools.NumPackages(), kMaxPackages);
}
}
Allocator& allocator() const { return ctx_.allocator; } enum class ParallelismType : uint8_t {
kNone,
// No parallelism.
kSequential,
// Parallelism at cluster level.
kCluster,
// Parallelism at package level.
kNested,
};
// Initial static partitioning of B rows across packages. struct MMOptions {
IndexRangePartition RangesOfNP(size_t max_packages, size_t N, ParallelismType parallelism_type_ = ParallelismType::kNested;
size_t sizeof_TC, size_t nr) const; uint8_t cluster_idx_ = 0;
};
// For `BindB` and `BindC`. struct MMNestedParallelPolicy {
size_t Node(size_t pkg_idx) const {
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
}
// Calls `func(pkg_idx)` for each package in parallel.
template <class Func> template <class Func>
void ForPkg(const size_t max_packages, const Func& func) { static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
const Func& func) {
if constexpr (kMaxPackages > 1) { if constexpr (kMaxPackages > 1) {
ctx_.pools.AllPackages().Run( ctx.pools.AllPackages().Run(
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()), 0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
[&](uint64_t task, size_t pkg_idx) { [&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx); HWY_DASSERT(task == pkg_idx);
(void)task; (void)task;
@ -95,16 +91,17 @@ class MMParallel {
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
template <class Func> template <class Func>
void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks, static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
size_t pkg_idx, const Func& func) { size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
// Single cluster: parallel-for over static partition of `range_np`. // Single cluster: parallel-for over static partition of `range_np`.
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers(); const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) { if (num_clusters == 1) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
const IndexRangePartition worker_ranges = StaticPartition( const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
return ParallelizeOneRange( return ParallelizeOneRange(
@ -120,9 +117,9 @@ class MMParallel {
ParallelizeOneRange( ParallelizeOneRange(
nx_ranges, all_clusters, nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) { [&](const IndexRange& nx_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base = const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster(); pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
// Parallel-for over sub-ranges of `cluster_range` within the cluster. // Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition( const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
@ -137,18 +134,19 @@ class MMParallel {
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
// rows). Calls `func(range_mc, range_nc, worker)`. // rows). Calls `func(range_mc, range_nc, worker)`.
template <class Func> template <class Func>
void ForRangesMC_NC(const IndexRangePartition& ranges_mc, static void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_nc, size_t pkg_idx, const IndexRangePartition& ranges_mc,
const Func& func) { const IndexRangePartition& ranges_nc,
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); size_t pkg_idx, const Func& func) {
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
// `all_clusters` is a pool with one worker per cluster in a package. // `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers(); const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for // Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins. // to reduce the number of fork-joins.
if (num_clusters == 1) { if (num_clusters == 1) {
const size_t cluster_idx = 0; const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
// Low-batch: avoid Divide/Remainder. // Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange( return ParallelizeOneRange(
@ -171,8 +169,8 @@ class MMParallel {
ranges_nc, all_clusters, ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) { [&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster(); pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(ranges_mc, cluster, ParallelizeOneRange(ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t thread) { [&](const IndexRange& range_mc, size_t thread) {
func(range_mc, range_nc, cluster_base + thread); func(range_mc, range_nc, cluster_base + thread);
@ -182,21 +180,18 @@ class MMParallel {
// Calls `func(row_a, worker)` in parallel. // Calls `func(row_a, worker)` in parallel.
template <class Func> template <class Func>
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx, static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
const Func& func) { size_t pkg_idx, const Func& func) {
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage(); const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
ctx_.pools.Pool(pkg_idx).Run( ctx.pools.Pool(pkg_idx).Run(
range_mc.begin(), range_mc.end(), range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); }); [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
} }
private:
ThreadingContext& ctx_;
}; };
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
// C is BF16/float. // C is BF16/float.
void BindC(MatPtr& C, MMParallel& parallel); void BindC(ThreadingContext& ctx, MatPtr& C);
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
#pragma pack(push, 1) // power of two size #pragma pack(push, 1) // power of two size
@ -250,15 +245,18 @@ class MMStorage {
// Internally threaded; must not be called concurrently with the same // Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`). // `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel) { MMStorage(ThreadingContext& ctx) {
// Per-package allocation so each can decompress A into its own copy. // Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`. // Must be padded, see `DoDecompressA`.
parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) { // Default to nested parallel policy.
MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
Allocator& allocator = ctx.allocator;
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>( pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd)); "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) { if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx); const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx]->ElementBytes(); pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
@ -607,9 +605,9 @@ class MMKeys {
// Per-MatMul-shape state. // Per-MatMul-shape state.
struct MMPerKey { struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr, MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
MMParallel& parallel) size_t sizeof_TC, size_t nr)
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) { : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
HWY_DASSERT(ranges_np.NumTasks() <= max_packages); HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
} }
@ -639,7 +637,6 @@ struct MatMulEnv {
// Whether to print the best config immediately after autotuning finished. // Whether to print the best config immediately after autotuning finished.
bool print_best = false; bool print_best = false;
MMParallel parallel;
MMStorage storage; MMStorage storage;
MMKeys keys; MMKeys keys;
std::vector<MMPerKey> per_key; std::vector<MMPerKey> per_key;