1.02x speedup: improve load balance and simplify ParallelFor

Remove ParallelizeOneRange/ParallelizeTwoRanges; use ParallelForAcrossClusters/ParallelForWithinCluster instead.

PiperOrigin-RevId: 823388890
Jan Wassenberg 2025-10-24 00:17:45 -07:00 committed by Copybara-Service
parent 085a34965a
commit a48e614f64
5 changed files with 166 additions and 251 deletions
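
The central load-balancing change, visible in ForRangesMC_NC below, is to collapse the two task-range indices (MC, NC) into a single 1D task index, so one pool.Run can distribute all MC*NC tasks even when ranges_mc has only one task. A minimal standalone sketch of that index math (plain C++ with made-up task counts; the actual code wraps the division in hwy::Divisor to avoid per-task integer division):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_mc = 3;  // e.g. ranges_mc.NumTasks()
  const uint32_t num_nc = 5;  // e.g. ranges_nc.NumTasks()
  // One flat loop stands in for a single pool.Run over all tasks, which gives
  // workers more available work than nesting an MC loop inside an NC loop.
  for (uint32_t task = 0; task < num_mc * num_nc; ++task) {
    const uint32_t idx_mc = task % num_mc;  // div_m.Remainder(task)
    const uint32_t idx_nc = task / num_mc;  // div_m.Divide(task)
    std::printf("task %2u -> mc %u, nc %u\n", task, idx_mc, idx_nc);
  }
  return 0;
}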

View File

@@ -103,17 +103,13 @@ struct MMParallelWithinCluster {
template <class Func>
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, size_t cluster_idx, const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMClusterForN);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(ranges_n, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForN),
ParallelPartitionWithinCluster(
range_n, n_multiple, inner_tasks, ctx, cluster_idx, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
func(worker_range, worker);
});
}
@@ -122,80 +118,57 @@ struct MMParallelWithinCluster {
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t cluster_idx,
const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMClusterForMCNC);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
ParallelizeOneRange(ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, base + worker);
// We are running on one pool, hence collapse into a 1D range.
const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
const auto get_mc = [&](uint64_t task) {
return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
};
const auto get_nc = [&](uint64_t task) {
return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
};
const size_t num_tasks = ranges_mc.NumTasks() * ranges_nc.NumTasks();
ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
[&](uint64_t task, size_t worker) {
func(get_mc(task), get_nc(task), worker);
});
} else {
ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster,
ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, base + worker); });
}
}
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t base = ctx.Worker(cluster_idx);
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMClusterForMC);
cluster.Run(
range_mc.begin(), range_mc.end(),
ctx.pool_callers.Get(Callers::kMMClusterForMC),
[&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
ParallelForWithinCluster(
range_mc.Num(), ctx, cluster_idx, caller,
[&](uint64_t i, size_t worker) { func(range_mc.begin() + i, worker); });
}
};
struct MMParallelHierarchical {
// Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
// Similar to `HierarchicalParallelFor`, but over *sub-ranges* of B rows in
// `range_n` governed by `n_multiple` and `inner_tasks`.
template <class Func>
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
size_t inner_tasks, size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0);
(void)caller_cluster_idx;
const hwy::pool::Caller caller = ctx.pool_callers.Get(Callers::kMMHierForN);
// Single cluster: parallel-for over static partition of `range_n`.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const IndexRangePartition ranges_n = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
ranges_n, cluster, caller,
// Assign clusters (if any) a sub-range of `range_n` (typically hundreds).
ParallelPartitionAcrossClusters(
range_n, n_multiple, /*inner_tasks=*/1, ctx, caller,
[&](const IndexRange& cluster_range, size_t cluster_idx) {
ParallelPartitionWithinCluster(
cluster_range, n_multiple, inner_tasks, ctx, cluster_idx, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, worker);
});
}
// Assign each cluster a sub-range of `range_n` (typically hundreds).
const IndexRangePartition ranges_n =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
ranges_n, all_clusters, caller,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
const size_t cluster_base = ctx.Worker(cluster_idx);
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster, caller,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, cluster_base + worker);
});
});
}
@@ -205,57 +178,44 @@ struct MMParallelHierarchical {
void ForRangesMC_NC(ThreadingContext& ctx,
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc,
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
size_t caller_cluster_idx, const Func& func) const {
HWY_DASSERT(caller_cluster_idx == 0);
(void)caller_cluster_idx;
const hwy::pool::Caller caller =
ctx.pool_callers.Get(Callers::kMMHierForMCNC);
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
// `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins.
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
ranges_nc, cluster, caller,
[&](const IndexRange& range_nc, size_t worker) {
func(ranges_mc.Range(0), range_nc, worker);
});
} else {
return ParallelizeTwoRanges(
ranges_mc, ranges_nc, cluster, caller,
[&](const IndexRange& range_mc, const IndexRange& range_nc,
size_t worker) { func(range_mc, range_nc, worker); });
}
}
// Collapse two range indices into a 1D range for better load-balancing,
// because `ranges_mc` may just have one task.
const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
const auto get_mc = [&](uint64_t task) {
return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
};
const auto get_nc = [&](uint64_t task) {
return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
};
const IndexRange all_range(0, ranges_mc.NumTasks() * ranges_nc.NumTasks());
// Multiple clusters: N across clusters (both are usually the larger), and
// M within each cluster. We assume auto-tuning finds small MC/NC tasks.
ParallelizeOneRange(
ranges_nc, all_clusters, caller,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
ParallelizeOneRange(ranges_mc, cluster, caller,
[&](const IndexRange& range_mc, size_t worker) {
func(range_mc, range_nc, cluster_base + worker);
ParallelPartitionAcrossClusters(
all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
[&](const IndexRange& cluster_range, size_t cluster_idx) {
ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx,
caller, [&](uint64_t i, size_t worker) {
const size_t task =
cluster_range.begin() + i;
func(get_mc(task), get_nc(task), worker);
});
});
}
// Calls `func(row_a, worker)` in parallel.
// No multiple/inner_tasks, so this is just HierarchicalParallelFor.
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t caller_cluster_idx, const Func& func) const {
HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t task, size_t worker) {
func(range_mc.begin() + task, worker);
});
HWY_DASSERT(caller_cluster_idx == 0);
(void)caller_cluster_idx;
HierarchicalParallelFor(
range_mc.Num(), ctx, Callers::kMMHierForMC,
[&](size_t i, size_t worker) { func(range_mc.begin() + i, worker); });
}
};
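
MMParallelHierarchical::ForN above now composes two generic helpers: ParallelPartitionAcrossClusters splits range_n across clusters at a granularity of n_multiple, and ParallelPartitionWithinCluster splits each cluster's sub-range into NumWorkers() * inner_tasks pieces. A standalone sketch of that two-level split (made-up sizes; Partition below only approximates StaticPartition's round-up-to-multiple behavior):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Range { size_t begin, end; };

// Splits [r.begin, r.end) into at most `pieces` chunks whose size is rounded
// up to a multiple of `multiple` (an approximation of StaticPartition).
static std::vector<Range> Partition(Range r, size_t pieces, size_t multiple) {
  size_t size = (r.end - r.begin + pieces - 1) / pieces;
  size = (size + multiple - 1) / multiple * multiple;
  std::vector<Range> out;
  for (size_t begin = r.begin; begin < r.end; begin += size) {
    out.push_back({begin, std::min(begin + size, r.end)});
  }
  return out;
}

int main() {
  const Range range_n{0, 96};     // rows of B
  const size_t n_multiple = 4;    // per-cluster task granularity
  const size_t num_clusters = 2;  // ctx.pools.NumClusters()
  const size_t num_workers = 4;   // cluster.NumWorkers()
  const size_t inner_tasks = 2;   // 1..4: over-partition for load balance
  for (const Range& cr : Partition(range_n, num_clusters, n_multiple)) {
    for (const Range& wr :
         Partition(cr, num_workers * inner_tasks, n_multiple)) {
      std::printf("cluster [%zu,%zu) worker range [%zu,%zu)\n", cr.begin,
                  cr.end, wr.begin, wr.end);
    }
  }
  return 0;
}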

View File

@@ -195,9 +195,10 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
ParallelForAcrossClusters(
get_col_c.NumTasks(), env.ctx, env.ctx.pool_callers.Get(Callers::kTest),
[&](size_t range_idx, size_t cluster_idx) HWY_ATTR {
const IndexRange cols_c = get_col_c.Range(range_idx);
for (size_t r : all_rows_c) {
TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) {
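
For reference, the loop above walks one statically partitioned column block of C per cluster-level task. A toy sequential version of that column-block pattern (made-up sizes; `block` stands in for the allocator-quantum multiple):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t M = 4, K = 3, N = 8;
  const size_t block = 4;  // ~ QuantumBytes() / sizeof(TB)
  std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);
  for (size_t col0 = 0; col0 < N; col0 += block) {  // one task per column block
    const size_t col1 = std::min(col0 + block, N);
    for (size_t r = 0; r < M; ++r) {                // all_rows_c
      for (size_t c = col0; c < col1; ++c) {        // cols_c
        float sum = 0.0f;
        for (size_t k = 0; k < K; ++k) sum += A[r * K + k] * B[k * N + c];
        C[r * N + c] = sum;
      }
    }
  }
  std::printf("C[0][0] = %.1f\n", C[0]);  // 3 * (1 * 2) = 6.0
  return 0;
}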

View File

@@ -262,43 +262,6 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
return IndexRangePartition(range, size);
}
// Parallel-for over a single range. This takes care of translating the task
// index to a range.
template <class Func>
void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
hwy::pool::Caller caller, const Func& func) {
const size_t num_tasks = get1.NumTasks();
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
const IndexRange range1 = get1.Range(task);
func(range1, thread);
});
}
// Parallel-for over the Cartesian product of the two sets of ranges. This
// combines their indices into a single 'task' so they can be executed by one
// `pool.Run`, which increases the amount of work available to workers and
// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
// the two ranges and the thread index within `pool`.
template <class Func>
void ParallelizeTwoRanges(const IndexRangePartition& get1,
const IndexRangePartition& get2,
hwy::ThreadPool& pool, hwy::pool::Caller caller,
const Func& func) {
const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
HWY_DASSERT(task < (uint64_t{1} << 32));
const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
HWY_DASSERT(idx1 < get1.NumTasks());
HWY_DASSERT(idx2 < get2.NumTasks());
const IndexRange range1 = get1.Range(idx1);
const IndexRange range2 = get2.Range(idx2);
func(range1, range2, thread);
});
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_

View File

@@ -154,42 +154,96 @@ enum class ParallelismStrategy : uint8_t {
kHierarchical,
};
// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
// over clusters of ONE package, then within each cluster.
// Helper functions used to implement `ParallelFor`, also reused in multiple
// places. User code should call `ParallelFor` instead, which accepts the more
// convenient `Callers` enum.
//
// These call `func(task, worker)` for each task in `[0, num_tasks)`.
// NOTE: the worker argument is actually the `cluster_idx`, so that `Func` can
// pass that to `ParallelForWithinCluster`.
template <class Func>
void ParallelForAcrossClusters(size_t num_tasks, ThreadingContext& ctx,
hwy::pool::Caller caller, const Func& func) {
ctx.pools.AllClusters().Run(
0, num_tasks, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
}
template <class Func>
void ParallelForWithinCluster(size_t num_tasks, ThreadingContext& ctx,
size_t cluster_idx, hwy::pool::Caller caller,
const Func& func) {
const size_t cluster_base = ctx.Worker(cluster_idx);
ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, cluster_base + worker);
});
}
// Calls `func(range, cluster_idx)`, for passing to `*WithinCluster`.
template <class Func>
void ParallelPartitionAcrossClusters(const IndexRange range,
size_t task_multiple, size_t inner_tasks,
ThreadingContext& ctx,
hwy::pool::Caller caller,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const IndexRangePartition ranges = StaticPartition(
range, ctx.pools.NumClusters() * inner_tasks, task_multiple);
ParallelForAcrossClusters(ranges.NumTasks(), ctx, caller,
[&](uint64_t task, size_t cluster_idx) {
func(ranges.Range(task), cluster_idx);
});
}
// Calls `func(range, worker)`.
template <class Func>
void ParallelPartitionWithinCluster(const IndexRange range,
size_t task_multiple, size_t inner_tasks,
ThreadingContext& ctx, size_t cluster_idx,
hwy::pool::Caller caller,
const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t num_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
const IndexRangePartition ranges =
StaticPartition(range, num_workers * inner_tasks, task_multiple);
ParallelForWithinCluster(
ranges.NumTasks(), ctx, cluster_idx, caller,
[&](uint64_t task, size_t worker) { func(ranges.Range(task), worker); });
}
// Parallelizes across clusters, then within each cluster.
template <class Func>
void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
Callers callers, const Func& func) {
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
// If few tasks, run on a single cluster. Also avoids a bit of overhead if
// there is only one cluster.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
func(task, thread);
});
// If at most one task per cluster worker, run on a single cluster to avoid
// the expensive cross-cluster barrier.
{
const size_t cluster_idx = 0;
const size_t cluster_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
if (HWY_UNLIKELY(num_tasks <= cluster_workers)) {
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
func);
}
}
// Assign each cluster a sub-range.
const IndexRangePartition ranges =
StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
ParallelizeOneRange(ranges, all_clusters, caller,
[&](const IndexRange& range, const size_t cluster_idx) {
hwy::ThreadPool& cluster =
ctx.pools.Cluster(cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
cluster.Run(range.begin(), range.end(), caller,
[&](uint64_t task, size_t thread) {
func(task, cluster_base + thread);
ParallelPartitionAcrossClusters(
IndexRange(0, num_tasks), /*task_multiple=*/1, /*inner_tasks=*/1, ctx,
caller, [&](const IndexRange& cluster_range, size_t cluster_idx) {
ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller,
[&](uint64_t i, size_t worker) {
func(cluster_range.begin() + i, worker);
});
});
}
// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
// number/type of workers determined by `parallelism`. `cluster_idx` is for
// `parallelism == kWithinCluster`, and should be 0 if unknown.
// number/type of workers determined by `parallelism`. NOTE: worker is actually
// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, Callers callers,
@@ -212,37 +266,25 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
}
case ParallelismStrategy::kAcrossClusters:
return ctx.pools.AllClusters().Run(
0, num_tasks, caller,
return ParallelForAcrossClusters(
num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: {
// Ensure the worker argument is unique across clusters, because it is
// used for TLS indexing for example in profiler.h.
const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
func(task, base + worker);
});
}
case ParallelismStrategy::kWithinCluster:
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
func);
case ParallelismStrategy::kFlat: {
// Check for single cluster; if not, we must compute `cluster_base` for
// consistent and non-overlapping worker indices.
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
return ctx.pools.Cluster(cluster_idx)
.Run(0, num_tasks, caller,
[&](uint64_t task, size_t worker) { func(task, worker); });
case ParallelismStrategy::kFlat:
// Choose a single pool: the only cluster, or across all clusters
// (slower synchronization, but more memory bandwidth)
if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
func);
}
return all_clusters.Run(0, num_tasks, caller,
return ParallelForAcrossClusters(num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) {
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
func(task, ctx.Worker(cluster_idx));
});
}
case ParallelismStrategy::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx, callers, func);
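
One invariant these helpers preserve: the `worker` passed to `func` stays globally unique across clusters, because ParallelForWithinCluster adds a per-cluster base offset (`ctx.Worker(cluster_idx)`, previously computed as cluster_idx * MaxWorkersPerCluster) to the pool-local index; this matters for TLS indexing, for example in profiler.h. A sketch of that index scheme with made-up topology values:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_clusters = 2;
  const size_t workers_per_cluster[2] = {6, 4};  // clusters may differ in size
  const size_t max_workers_per_cluster = 6;  // ctx.pools.MaxWorkersPerCluster()
  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
    // Per-cluster base, as returned by ctx.Worker(cluster_idx).
    const size_t cluster_base = cluster_idx * max_workers_per_cluster;
    for (size_t w = 0; w < workers_per_cluster[cluster_idx]; ++w) {
      std::printf("cluster %zu local %zu -> global worker %zu\n", cluster_idx,
                  w, cluster_base + w);
    }
  }
  return 0;
}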

View File

@@ -202,57 +202,6 @@ TEST(ThreadingTest, TestStaticPartition) {
}
}
TEST(ThreadingTest, TestParallelizeOneRange) {
const IndexRange range(0, 10);
const IndexRangePartition partition = StaticPartition(range, 2, 4);
hwy::ThreadPool null_pool(0);
size_t calls = 0;
ParallelizeOneRange(partition, null_pool, kCaller,
[&](const IndexRange& range, size_t) {
if (++calls == 1) {
HWY_ASSERT(range.begin() == 0 && range.end() == 8);
} else {
HWY_ASSERT(range.begin() == 8 && range.end() == 10);
}
});
HWY_ASSERT(calls == 2);
}
TEST(ThreadingTest, TestParallelizeTwoRanges) {
const IndexRangePartition partition1 =
StaticPartition(IndexRange(0, 10), 2, 4);
const IndexRangePartition partition2 =
MaxSizePartition(IndexRange(128, 256), 32, 32);
HWY_ASSERT(partition2.NumTasks() == 4);
hwy::ThreadPool null_pool(0);
{
size_t calls = 0;
ParallelizeTwoRanges(
partition1, partition2, null_pool, kCaller,
[&](const IndexRange& range1, const IndexRange& range2, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
// Also swap order to test Remainder() logic.
{
size_t calls = 0;
ParallelizeTwoRanges(
partition2, partition1, null_pool, kCaller,
[&](const IndexRange& range2, const IndexRange& range1, size_t) {
++calls;
HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
HWY_ASSERT(range2.begin() % 32 == 0);
HWY_ASSERT(range2.Num() % 32 == 0);
});
HWY_ASSERT(calls == 2 * 4);
}
}
static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];
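
The removed tests above pinned down the partition shapes fed to the old helpers; the same (range1, range2) pairs remain reachable through the 1D task collapse. A standalone check (no thread pool; partition contents hard-coded to the values the removed test expected):

#include <cassert>
#include <cstddef>
#include <vector>

struct Range { size_t begin, end; };

int main() {
  // StaticPartition(IndexRange(0, 10), 2, 4) -> [0,8), [8,10).
  const std::vector<Range> partition1 = {{0, 8}, {8, 10}};
  // MaxSizePartition(IndexRange(128, 256), 32, 32) -> four ranges of 32.
  const std::vector<Range> partition2 = {{128, 160}, {160, 192}, {192, 224},
                                         {224, 256}};
  size_t calls = 0;
  const size_t num1 = partition1.size();
  for (size_t task = 0; task < num1 * partition2.size(); ++task) {
    const Range& r1 = partition1[task % num1];  // Remainder()
    const Range& r2 = partition2[task / num1];  // Divide()
    ++calls;
    assert(r1.begin == 0 || r1.begin == 8);
    assert(r2.begin % 32 == 0 && (r2.end - r2.begin) % 32 == 0);
  }
  assert(calls == 2 * 4);
  return 0;
}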