Refactor Matmul to use a policy class for parallelization.

PiperOrigin-RevId: 800864489
Marie White 2025-08-29 05:40:06 -07:00 committed by Copybara-Service
parent 6c39a2dea4
commit 973e284ed6
4 changed files with 125 additions and 121 deletions
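The refactor replaces the stateful `MMParallel` helper, which captured a `ThreadingContext&` at construction, with parallelization policy classes whose members are static and receive the context explicitly; `MMNestedParallelPolicy` is the policy introduced here. For orientation, this is roughly the interface a policy exposes, collected from the signatures in the hunks below (a summary for the reader, not additional code in the commit):

struct ExamplePolicy {  // same shape as MMNestedParallelPolicy below
  // Calls func(pkg_idx) for each package in parallel.
  template <class Func>
  static void ForPkg(ThreadingContext& ctx, size_t max_packages,
                     const Func& func);

  // Cluster-aware parallel-for over B rows; calls func(worker_range, worker).
  template <class Func>
  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
                    const Func& func);

  // Parallel-for over blocks; calls func(range_mc, range_nc, worker).
  template <class Func>
  static void ForRangesMC_NC(ThreadingContext& ctx,
                             const IndexRangePartition& ranges_mc,
                             const IndexRangePartition& ranges_nc,
                             size_t pkg_idx, const Func& func);

  // Calls func(row_a, worker) for each row of A in range_mc.
  template <class Func>
  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                         size_t pkg_idx, const Func& func);
};

Call sites pass a policy object as a tag (e.g. `MMNestedParallelPolicy()`), so the matmul code becomes a function template over the policy type instead of being hard-wired to one scheduler.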

View File

@@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
const size_t start = owners.size();
owners.resize(start + tensors.size());
MMParallel parallel(ctx);
// Allocate in parallel because faulting in large tensors is slow.
ctx.pools.Pool().Run(
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
@@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
tensor.padding);
BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
});
}

View File

@@ -779,17 +779,19 @@ class MMPerPackage {
// B and maybe A are decompressed several call layers lower, but not all
// member functions depend on TA/TB, so pass them as an argument instead of
// templating the class.
template <typename TA, typename TB, typename TC>
HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
const MatPtrT<TA>& A, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
if constexpr (WantDecompressA<TA>()) {
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
DecompressA(A, A_view);
DecompressA<MMParallelPolicyT>(A, A_view);
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
DispatchOrder(A_view, A_padded, B, C_rows);
DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
} else {
const bool A_padded = HasPadding(A);
DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
C_rows);
}
}
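The comment above explains why `MMPerPackage` itself is not templated on the element or policy types: the policy arrives as a value/reference argument, and the member functions stay function templates. A minimal, self-contained sketch of this tag-dispatch idiom, with invented names that are not part of the commit:

#include <cstddef>
#include <cstdio>

struct SequentialPolicy {
  template <class Func>
  static void For(std::size_t begin, std::size_t end, const Func& func) {
    for (std::size_t i = begin; i < end; ++i) func(i);
  }
};

struct NoisyPolicy {
  template <class Func>
  static void For(std::size_t begin, std::size_t end, const Func& func) {
    for (std::size_t i = begin; i < end; ++i) {
      std::printf("task %zu\n", i);
      func(i);
    }
  }
};

// The policy type is deduced from the otherwise-unused tag argument, so only
// this function is a template, not the enclosing class.
template <class PolicyT>
void RunAll(PolicyT /*tag*/, std::size_t n) {
  PolicyT::For(0, n, [](std::size_t i) { (void)i; /* work for task i */ });
}

int main() {
  RunAll(SequentialPolicy(), 4);
  RunAll(NoisyPolicy(), 4);
}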
@@ -828,28 +830,30 @@ class MMPerPackage {
}
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded,
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
switch (order_) {
case MMOrder::kNT:
return DoNT(A, A_padded, B, C_rows);
return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_K:
return DoNT_K(A, A_padded, B, C_rows);
return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_MT:
return DoNT_MT(A, A_padded, B, C_rows);
return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
case MMOrder::kNT_MT_K:
return DoNT_MT_K(A, A_padded, B, C_rows);
return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
default:
HWY_UNREACHABLE;
}
}
// Single M and K ranges, parallel N. Fills all of C directly.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -861,9 +865,9 @@ class MMPerPackage {
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
// Similar to `loop_nc` below, but here we hoisted `A_view`.
args_.env->parallel.ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMParallelPolicyT::ForNP(
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
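This hunk shows the mechanical change applied throughout `MMPerPackage`: calls on the `args_.env->parallel` member become calls to static members of the policy type, with the `ThreadingContext` passed per call. Schematically (the lambda is elided and written as `func`):

  // before: stateful member object that captured the context
  args_.env->parallel.ForNP(range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
                            pkg_idx_, func);
  // after: stateless policy, context passed explicitly
  MMParallelPolicyT::ForNP(args_.env->ctx, range_np_, MultipleNP(sizeof(TC)),
                           inner_tasks_, pkg_idx_, func);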
@@ -881,9 +885,10 @@ class MMPerPackage {
}
// Single M range, parallel N, sequential K. Sets C, then accumulates.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,9 +914,9 @@ class MMPerPackage {
}
};
args_.env->parallel.ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMParallelPolicyT::ForNP(
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
@@ -930,9 +935,10 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0);
@@ -942,8 +948,8 @@ class MMPerPackage {
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
// except for the profiler strings and `out_tag`.
args_.env->parallel.ForRangesMC_NC(
ranges_mc_, ranges_nc_, pkg_idx_,
MMParallelPolicyT::ForRangesMC_NC(
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR {
MMZone mm_zone;
@@ -965,9 +971,10 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -992,8 +999,8 @@ class MMPerPackage {
out_tag, args_, C_rows);
}
}; // loop_nc
args_.env->parallel.ForRangesMC_NC(
ranges_mc_, ranges_nc_, pkg_idx_,
MMParallelPolicyT::ForRangesMC_NC(
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) HWY_ATTR {
MMZone mm_zone;
@@ -1014,6 +1021,7 @@ class MMPerPackage {
}
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
template <typename MMParallelPolicyT>
HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMParA par_a) const {
@@ -1064,16 +1072,16 @@ class MMPerPackage {
// line to avoid false sharing.
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
args_.env->parallel.ForNP(
all_K, multiple_K, inner_tasks, pkg_idx_,
MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
pkg_idx_,
[&](const IndexRange& range_K, size_t worker) {
do_range(all_M, range_K, worker);
});
break;
}
case MMParA::kM:
args_.env->parallel.ForRangeMC(
all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
MMParallelPolicyT::ForRangeMC(
args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
do_range(IndexRange(row_a, row_a + 1), all_K, worker);
});
break;
@@ -1081,12 +1089,13 @@ class MMPerPackage {
}
// Autotuning wrapper for `DoDecompressA`.
template <typename MMParallelPolicyT>
HWY_INLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view) const {
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
if (HWY_LIKELY(autotune.Best())) {
return DoDecompressA(A, A_view, *autotune.Best());
return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
}
// First call: generate candidates.
@@ -1099,7 +1108,7 @@ class MMPerPackage {
const MMParA& par_a = autotune.NextConfig();
const uint64_t t0 = hwy::timer::Start();
DoDecompressA(A, A_view, par_a);
DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
const uint64_t t1 =
args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
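The wrapper above follows the usual autotune shape: reuse `Best()` once it exists, otherwise time the next candidate and report the elapsed ticks. A sketch of that step in isolation; `RunOnce` is a placeholder, candidate generation on the first call is omitted, and `Best()` returning a pointer-like handle is inferred from its use above rather than stated elsewhere in the diff:

template <class Config, class RunOnce>
void AutotuneStep(MMAutoTune<Config>& autotune, bool have_timer_stop,
                  const RunOnce& run_once) {
  if (HWY_LIKELY(autotune.Best())) {  // converged: reuse the best config
    return run_once(*autotune.Best());
  }
  const Config& config = autotune.NextConfig();  // next candidate to measure
  const uint64_t t0 = hwy::timer::Start();
  run_once(config);
  const uint64_t t1 =
      have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
  autotune.NotifyTicks(t1 - t0);  // the autotuner tracks the minimum
}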
@@ -1185,19 +1194,21 @@ struct MMImpl {
if constexpr (kMaxPackages > 1) {
// Outermost loop: static NUMA-aware partition of B rows across packages.
args.env->parallel.ForPkg(
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
MMNestedParallelPolicy::ForPkg(
args.env->ctx, args.per_key->ranges_np.NumTasks(),
[&](size_t pkg_idx) {
MMZone mm_zone;
mm_zone.MaybeEnter(pkg_idx, zone, args);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
C_rows);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
});
} else {
const size_t pkg_idx = 0;
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
}
}
};
@@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// invalidates `MMAutoTune::Best()`
index = env.per_key.size();
env.per_key.push_back(
MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
}
MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;

View File

@@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
return np_multiple;
}
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
size_t N, size_t sizeof_TC, size_t nr) {
const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
return StaticPartition(
IndexRange(0, N), num_packages,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
}
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
}
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return;
if (B.Rows() == 1) return;
PROFILER_ZONE("Startup.BindB");
const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that
@@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
}
// C is BF16/float
void BindC(MatPtr& C, MMParallel& parallel) {
Allocator& allocator = parallel.allocator();
void BindC(ThreadingContext& ctx, MatPtr& C) {
Allocator& allocator = ctx.allocator;
if (!allocator.ShouldBind()) return;
PROFILER_ZONE("Startup.BindC");
const IndexRangePartition ranges_np =
parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) {
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t node = parallel.Node(pkg_idx);
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
}
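`BindB` and `BindC` both clamp the bound byte range to `BasePageBytes()` granularity before calling `BindMemory`: NUMA binding works on whole pages, and rounding inward keeps the bind from touching pages that belong to a neighboring package's range. A sketch of the per-row loop with the rounding made explicit; `RoundUpTo` for the begin offset and the emptiness guard are assumptions for this sketch, since only the `RoundDownTo` of the end offset is visible in the hunk above:

  const size_t page = allocator.BasePageBytes();
  const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), page);
  const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), page);
  for (size_t im = 0; im < C.Rows(); ++im) {
    if (begin < end) {  // skip ranges smaller than one page
      ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
    }
  }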

View File

@@ -53,35 +53,31 @@ constexpr size_t kNR = 4;
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
// ThreadingContext to shorten call sites.
class MMParallel {
public:
// `ctx` must outlive this object.
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
if (ctx_.pools.NumPackages() > kMaxPackages) {
HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
ctx_.pools.NumPackages(), kMaxPackages);
}
}
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
size_t N, size_t sizeof_TC, size_t nr);
Allocator& allocator() const { return ctx_.allocator; }
enum class ParallelismType : uint8_t {
kNone,
// No parallelism.
kSequential,
// Parallelism at cluster level.
kCluster,
// Parallelism at package level.
kNested,
};
// Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const;
struct MMOptions {
ParallelismType parallelism_type_ = ParallelismType::kNested;
uint8_t cluster_idx_ = 0;
};
// For `BindB` and `BindC`.
size_t Node(size_t pkg_idx) const {
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
}
// Calls `func(pkg_idx)` for each package in parallel.
struct MMNestedParallelPolicy {
template <class Func>
void ForPkg(const size_t max_packages, const Func& func) {
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
const Func& func) {
if constexpr (kMaxPackages > 1) {
ctx_.pools.AllPackages().Run(
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
ctx.pools.AllPackages().Run(
0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
[&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx);
(void)task;
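Note that this header also gains `ParallelismType` and `MMOptions` (defaulting to `kNested`), while every call site in this commit instantiates `MMNestedParallelPolicy` directly. Purely as a hypothetical illustration of how the enum could select a policy: neither the helper below nor `MMClusterPolicy`/`MMSequentialPolicy` exist in this diff, they are assumed names.

template <class Func>
void WithParallelPolicy(const MMOptions& options, const Func& func) {
  switch (options.parallelism_type_) {
    case ParallelismType::kCluster:
      return func(MMClusterPolicy());     // assumed name, not in this commit
    case ParallelismType::kSequential:
      return func(MMSequentialPolicy());  // assumed name, not in this commit
    case ParallelismType::kNested:
    default:
      return func(MMNestedParallelPolicy());
  }
}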
@@ -95,16 +91,17 @@ class MMParallel {
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
template <class Func>
void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
size_t pkg_idx, const Func& func) {
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
// Single cluster: parallel-for over static partition of `range_np`.
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
return ParallelizeOneRange(
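`ForNP` statically partitions `range_np` across the clusters of one package and then across each cluster's workers (with `inner_tasks` sub-tasks per worker for load balancing), calling the functor with each sub-range and a package-wide worker index. An illustrative call; the concrete numbers and the lambda body are examples, not values from the commit:

  const IndexRange range_np(0, 4096);
  const size_t nx_multiple = 64;   // per-cluster task granularity
  const size_t inner_tasks = 2;    // must be in [1, 4], see the HWY_DASSERT
  const size_t pkg_idx = 0;
  MMNestedParallelPolicy::ForNP(
      ctx, range_np, nx_multiple, inner_tasks, pkg_idx,
      [&](const IndexRange& range_nc, size_t worker) {
        // Process rows [range_nc.begin(), range_nc.end()) on `worker`.
      });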
@@ -120,9 +117,9 @@ class MMParallel {
ParallelizeOneRange(
nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
@@ -137,18 +134,19 @@ class MMParallel {
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
// rows). Calls `func(range_mc, range_nc, worker)`.
template <class Func>
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t pkg_idx,
const Func& func) {
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
static void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc,
size_t pkg_idx, const Func& func) {
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
// `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins.
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
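When the package has a single (big) cluster, the two block indices are flattened into one task index so there is only a single fork-join; the general path recovers (mc, nc) from the flat index by division and remainder, which the low-batch special case above avoids. A sketch of the flattened loop, with plain / and % in place of whatever fast-division helper the real code uses:

  const size_t num_mc = ranges_mc.NumTasks();
  const size_t num_nc = ranges_nc.NumTasks();
  cluster.Run(0, num_mc * num_nc, [&](uint64_t task, size_t thread) {
    const size_t idx_mc = static_cast<size_t>(task % num_mc);
    const size_t idx_nc = static_cast<size_t>(task / num_mc);
    func(ranges_mc.Range(idx_mc), ranges_nc.Range(idx_nc), thread);
  });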
@@ -171,8 +169,8 @@ class MMParallel {
ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base =
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t thread) {
func(range_mc, range_nc, cluster_base + thread);
@@ -182,21 +180,18 @@ class MMParallel {
// Calls `func(row_a, worker)` in parallel.
template <class Func>
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
const Func& func) {
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
ctx_.pools.Pool(pkg_idx).Run(
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t pkg_idx, const Func& func) {
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
ctx.pools.Pool(pkg_idx).Run(
range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
}
private:
ThreadingContext& ctx_;
};
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
// C is BF16/float.
void BindC(MatPtr& C, MMParallel& parallel);
void BindC(ThreadingContext& ctx, MatPtr& C);
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
#pragma pack(push, 1) // power of two size
@@ -250,15 +245,18 @@ class MMStorage {
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel) {
MMStorage(ThreadingContext& ctx) {
// Per-package allocation so each can decompress A into its own copy.
// Must be padded, see `DoDecompressA`.
parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
// Default to nested parallel policy.
MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
Allocator& allocator = ctx.allocator;
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx);
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
@@ -607,9 +605,9 @@ class MMKeys {
// Per-MatMul-shape state.
struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr)
: ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
}
@@ -639,7 +637,6 @@ struct MatMulEnv {
// Whether to print the best config immediately after autotuning finished.
bool print_best = false;
MMParallel parallel;
MMStorage storage;
MMKeys keys;
std::vector<MMPerKey> per_key;