Add an in-cluster parallel policy (MMClusterParallelPolicy). Update all parallel-policy entry points to take a cluster_idx.

PiperOrigin-RevId: 802016308
Author: Marie White, 2025-09-02 00:14:05 -07:00 (committed by Copybara-Service)
parent 27cb8e12d9
commit 3737224132
2 changed files with 96 additions and 27 deletions

File 1 of 2:

@@ -728,9 +728,10 @@ class MMPerPackage {
  public:
   MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
-               size_t pkg_idx, const IndexRange& range_np)
+               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
       : args_(args),
         pkg_idx_(pkg_idx),
+        cluster_idx_(cluster_idx),
         range_np_(range_np),
         mr_(config.MR()),
         ranges_mc_(config.RangesOfMC(A.rows)),
@@ -821,7 +822,8 @@ class MMPerPackage {
       // Similar to `loop_nc` below, but here we hoisted `A_view`.
       MMParallelPolicyT::ForNP(
           args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-          pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          pkg_idx_, cluster_idx_,
+          [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
             MMZone mm_zone;
             mm_zone.MaybeEnter(worker, zone, args_);
@@ -869,7 +871,8 @@ class MMPerPackage {
       MMParallelPolicyT::ForNP(
           args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-          pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+          pkg_idx_, cluster_idx_,
+          [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
             MMZone mm_zone;
             mm_zone.MaybeEnter(worker, zone, args_);
@@ -901,7 +904,7 @@ class MMPerPackage {
       // Sequential loop over NC/MC/KC, similar to `loop_nc` below
       // except for the profiler strings and `out_tag`.
       MMParallelPolicyT::ForRangesMC_NC(
-          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
           [&](const IndexRange& range_mc, const IndexRange& range_nc,
               size_t worker) HWY_ATTR {
             MMZone mm_zone;
@@ -951,7 +954,7 @@ class MMPerPackage {
         }
       };  // loop_nc
       MMParallelPolicyT::ForRangesMC_NC(
-          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+          args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
           [&](const IndexRange& range_mc, const IndexRange& range_nc,
               size_t worker) HWY_ATTR {
             MMZone mm_zone;
@@ -1024,7 +1027,7 @@ class MMPerPackage {
         const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
         MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
-                                 pkg_idx_,
+                                 pkg_idx_, cluster_idx_,
                                  [&](const IndexRange& range_K, size_t worker) {
                                    do_range(all_M, range_K, worker);
                                  });
@@ -1032,7 +1035,8 @@ class MMPerPackage {
       }
       case MMParA::kM:
         MMParallelPolicyT::ForRangeMC(
-            args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+            args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
+            [&](size_t row_a, size_t worker) {
               do_range(IndexRange(row_a, row_a + 1), all_K, worker);
             });
         break;
@@ -1106,6 +1110,7 @@ class MMPerPackage {
   const MMArgs args_;  // copy for locality
   const size_t pkg_idx_;
+  const size_t cluster_idx_;  // 0 for sequential and nested.
   const IndexRange range_np_;
 
   // From MMConfig:
@@ -1135,23 +1140,26 @@ struct MMImpl {
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
-                                    const MMConfig& config,
-                                    ParallelismType parallelism_type) {
+                                    const MMConfig& config, MMOptions options) {
     PROFILER_ZONE("MM.DoMatMul");
     const size_t pkg_idx = 0;
     HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
     const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-    switch (parallelism_type) {
+    switch (options.parallelism_type) {
       case ParallelismType::kNested:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMNestedParallelPolicy(), A, B, C_rows);
+        HWY_DASSERT(options.cluster_idx == 0);
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMNestedParallelPolicy(), A, B, C_rows);
         break;
-      case ParallelismType::kSequential:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMSequentialPolicy(), A, B, C_rows);
+      case ParallelismType::kNone:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMSequentialPolicy(), A, B, C_rows);
+        break;
+      case ParallelismType::kCluster:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
         break;
       default:
         HWY_ABORT("Parallelism type not implemented.");
         break;
     }
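DoMatMul now receives the whole MMOptions rather than just a ParallelismType. The definition of MMOptions is not part of this diff; judging only from the two fields used above, a plausible minimal shape (a sketch reconstructed from usage, not the actual declaration) would be:

#include <cstddef>

// Hypothetical reconstruction; only these two members are exercised by the
// diff above. kNone runs MMSequentialPolicy on the calling thread, kCluster
// pins work to pools.Cluster(pkg_idx, cluster_idx), and kNested fans out
// across all clusters of the package (hence the HWY_DASSERT that
// cluster_idx == 0 in that case). Enumerator order is a guess.
enum class ParallelismType { kNone, kCluster, kNested };

struct MMOptions {
  ParallelismType parallelism_type = ParallelismType::kNested;
  size_t cluster_idx = 0;  // Only meaningful for kCluster.
};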
@@ -1210,8 +1218,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
                     add);
 
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
-                     options.parallelism_type);
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
     return &per_key;
   }
@@ -1240,7 +1247,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
     const MMConfig& cfg = tuner.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
+    MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
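For context, the surrounding autotuner loop (only partially visible in this hunk) times each candidate MMConfig and remembers the fastest, which is the tuner.Best() consulted in the previous hunk. A schematic of that measure-and-keep-best pattern, using std::chrono in place of hwy::timer and invented ToyConfig/RunMatMulWith stand-ins:

#include <chrono>
#include <cstdio>
#include <vector>

struct ToyConfig { int tile; };                          // invented candidate
static void RunMatMulWith(const ToyConfig& /*cfg*/) {}   // work elided

int main() {
  std::vector<ToyConfig> candidates = {{32}, {64}, {128}};
  ToyConfig best = candidates[0];
  double best_sec = 1e300;
  for (const ToyConfig& cfg : candidates) {
    const auto t0 = std::chrono::steady_clock::now();
    RunMatMulWith(cfg);  // the equivalent of MMImpl::DoMatMul above
    const auto t1 = std::chrono::steady_clock::now();
    const double sec = std::chrono::duration<double>(t1 - t0).count();
    if (sec < best_sec) { best_sec = sec; best = cfg; }  // keep the fastest
  }
  printf("best tile: %d (%.9f s)\n", best.tile, best_sec);
  return 0;
}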

File 2 of 2:

@@ -81,9 +81,10 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t cluster_idx, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     func(range_np, base_idx);
   }
 
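The recurring change across the sequential entry points is the base worker index: it now reserves a disjoint block of indices for each (package, cluster) pair, so per-worker state does not alias when several clusters run MatMul concurrently. A minimal standalone sketch of that flattening, with made-up worker counts standing in for the pool geometry that ThreadingContext provides:

#include <cstddef>
#include <cstdio>

// Hypothetical geometry; the real values come from ctx.pools
// (MaxWorkersPerPackage / MaxWorkersPerCluster).
constexpr size_t kMaxWorkersPerCluster = 8;
constexpr size_t kClustersPerPackage = 4;
constexpr size_t kMaxWorkersPerPackage =
    kClustersPerPackage * kMaxWorkersPerCluster;

// Mirrors the new base_idx computation above: each (pkg, cluster) pair
// owns a disjoint block of worker indices.
size_t BaseWorkerIdx(size_t pkg_idx, size_t cluster_idx) {
  return pkg_idx * kMaxWorkersPerPackage + cluster_idx * kMaxWorkersPerCluster;
}

int main() {
  printf("%zu\n", BaseWorkerIdx(1, 2));  // 1 * 32 + 2 * 8 = 48
  return 0;
}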
@@ -91,8 +92,10 @@ struct MMSequentialPolicy {
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
@@ -105,14 +108,68 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
       func(row_a, base_idx);
     }
   }
 };
 
+struct MMClusterParallelPolicy {
+  template <class Func>
+  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
+                     const Func& func) {
+    func(/*pkg_idx=*/0);
+  }
+
+  template <class Func>
+  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
+                    size_t cluster_idx, const Func& func) {
+    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const IndexRangePartition worker_ranges = StaticPartition(
+        range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
+    ParallelizeOneRange(worker_ranges, cluster,
+                        [&](const IndexRange& worker_range, size_t worker) {
+                          func(worker_range, worker);
+                        });
+  }
+
+  template <class Func>
+  static void ForRangesMC_NC(ThreadingContext& ctx,
+                             const IndexRangePartition& ranges_mc,
+                             const IndexRangePartition& ranges_nc,
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    // Low-batch: avoid Divide/Remainder.
+    if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
+      ParallelizeOneRange(ranges_nc, cluster,
+                          [&](const IndexRange& range_nc, size_t thread) {
+                            func(ranges_mc.Range(0), range_nc, thread);
+                          });
+    } else {
+      ParallelizeTwoRanges(
+          ranges_mc, ranges_nc, cluster,
+          [&](const IndexRange& range_mc, const IndexRange& range_nc,
+              size_t thread) { func(range_mc, range_nc, thread); });
+    }
+  }
+
+  template <class Func>
+  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    cluster.Run(range_mc.begin(), range_mc.end(),
+                [&](uint64_t row_a, size_t thread) { func(row_a, thread); });
+  }
+};
+
 struct MMNestedParallelPolicy {
   template <class Func>
   static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
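MMClusterParallelPolicy::ForNP splits range_np into at most cluster.NumWorkers() * inner_tasks contiguous tasks whose sizes respect nx_multiple, then runs them on that single cluster pool. The toy below models only the partitioning step; ToyStaticPartition is written from what the call site implies and is not the real StaticPartition:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Range { size_t begin, end; };  // stand-in for IndexRange

// Splits [r.begin, r.end) into at most max_tasks pieces whose sizes are
// rounded up to `multiple`, except possibly the final remainder piece.
std::vector<Range> ToyStaticPartition(Range r, size_t max_tasks,
                                      size_t multiple) {
  const size_t n = r.end - r.begin;
  size_t task = (n + max_tasks - 1) / max_tasks;  // ceil(n / max_tasks)
  task = (task + multiple - 1) / multiple * multiple;
  std::vector<Range> out;
  for (size_t b = r.begin; b < r.end; b += task) {
    out.push_back({b, std::min(b + task, r.end)});
  }
  return out;
}

int main() {
  // E.g. 8 cluster workers * inner_tasks = 2 => up to 16 tasks, aligned to 4.
  for (const Range& t : ToyStaticPartition({0, 300}, 16, 4)) {
    printf("[%zu, %zu)\n", t.begin, t.end);
  }
  return 0;
}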
@@ -132,10 +189,11 @@ struct MMNestedParallelPolicy {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t /*cluster_idx*/, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
@@ -175,11 +233,13 @@ struct MMNestedParallelPolicy {
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
+                             size_t pkg_idx, size_t /*cluster_idx*/,
+                             const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
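Both ForRangesMC_NC implementations iterate the cross product of ranges_mc and ranges_nc. The "Low-batch: avoid Divide/Remainder" comment in the cluster policy refers to how a two-range parallel-for presumably decomposes a flat task index into an (mc, nc) pair; with a single MC range that division is unnecessary. A toy model of the decomposition, assuming ParallelizeTwoRanges enumerates the cross product this way:

#include <cstddef>
#include <cstdio>

// One Divide and one Remainder per task recover the 2-D coordinates.
void ToyTwoRanges(size_t num_mc, size_t num_nc) {
  for (size_t task = 0; task < num_mc * num_nc; ++task) {
    const size_t i_mc = task / num_nc;  // Divide
    const size_t i_nc = task % num_nc;  // Remainder
    printf("task %zu -> (mc=%zu, nc=%zu)\n", task, i_mc, i_nc);
  }
}

int main() {
  ToyTwoRanges(2, 3);  // with num_mc == 1, the fast path above skips this
  return 0;
}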
@@ -221,9 +281,11 @@ struct MMNestedParallelPolicy {
   }
 
   // Calls `func(row_a, worker)` in parallel.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
+                         size_t pkg_idx, size_t /*cluster_idx*/,
+                         const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     ctx.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),