diff --git a/ops/matmul.h b/ops/matmul.h
index ea7c090..fcb3063 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -103,18 +103,14 @@ struct MMParallelWithinCluster {
 
   template <class Func>
   void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
             size_t inner_tasks, size_t cluster_idx, const Func& func) const {
-    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+    const hwy::pool::Caller caller =
+        ctx.pool_callers.Get(Callers::kMMClusterForN);
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-    const size_t base = ctx.Worker(cluster_idx);
-
-    const IndexRangePartition ranges_n = StaticPartition(
-        range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
-    ParallelizeOneRange(ranges_n, cluster,
-                        ctx.pool_callers.Get(Callers::kMMClusterForN),
-                        [&](const IndexRange& worker_range, size_t worker) {
-                          func(worker_range, base + worker);
-                        });
+    ParallelPartitionWithinCluster(
+        range_n, n_multiple, inner_tasks, ctx, cluster_idx, caller,
+        [&](const IndexRange& worker_range, size_t worker) {
+          func(worker_range, worker);
+        });
   }
 
   template <class Func>
@@ -122,79 +118,56 @@ struct MMParallelWithinCluster {
                       const IndexRangePartition& ranges_mc,
                       const IndexRangePartition& ranges_nc, size_t cluster_idx,
                       const Func& func) const {
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-    const size_t base = ctx.Worker(cluster_idx);
+    const hwy::pool::Caller caller =
+        ctx.pool_callers.Get(Callers::kMMClusterForMCNC);
-    // Low-batch: avoid Divide/Remainder.
-    if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
-      ParallelizeOneRange(ranges_nc, cluster,
-                          ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
-                          [&](const IndexRange& range_nc, size_t worker) {
-                            func(ranges_mc.Range(0), range_nc, base + worker);
-                          });
-    } else {
-      ParallelizeTwoRanges(
-          ranges_mc, ranges_nc, cluster,
-          ctx.pool_callers.Get(Callers::kMMClusterForMCNC),
-          [&](const IndexRange& range_mc, const IndexRange& range_nc,
-              size_t worker) { func(range_mc, range_nc, base + worker); });
-    }
+    // We are running on one pool, hence collapse into a 1D range.
+    const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
+    const auto get_mc = [&](uint64_t task) {
+      return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
+    };
+    const auto get_nc = [&](uint64_t task) {
+      return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
+    };
+    const size_t num_tasks = ranges_mc.NumTasks() * ranges_nc.NumTasks();
+
+    ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
+                             [&](uint64_t task, size_t worker) {
+                               func(get_mc(task), get_nc(task), worker);
+                             });
   }
 
   template <class Func>
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
-    hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-    const size_t base = ctx.Worker(cluster_idx);
+    const hwy::pool::Caller caller =
+        ctx.pool_callers.Get(Callers::kMMClusterForMC);
-    cluster.Run(
-        range_mc.begin(), range_mc.end(),
-        ctx.pool_callers.Get(Callers::kMMClusterForMC),
-        [&](uint64_t row_a, size_t worker) { func(row_a, base + worker); });
+    ParallelForWithinCluster(
+        range_mc.Num(), ctx, cluster_idx, caller,
+        [&](uint64_t i, size_t worker) { func(range_mc.begin() + i, worker); });
  }
};
 
 struct MMParallelHierarchical {
-  // Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is
-  // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
+  // Similar to `HierarchicalParallelFor`, but over *sub-ranges* of B rows in
+  // `range_n` governed by `n_multiple` and `inner_tasks`.
   template <class Func>
   void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
-            size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
+            size_t inner_tasks, size_t caller_cluster_idx,
             const Func& func) const {
-    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
     HWY_DASSERT(caller_cluster_idx == 0);
+    (void)caller_cluster_idx;
     const hwy::pool::Caller caller =
         ctx.pool_callers.Get(Callers::kMMHierForN);
-    // Single cluster: parallel-for over static partition of `range_n`.
-    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
-    const size_t num_clusters = all_clusters.NumWorkers();
-    if (num_clusters == 1) {
-      const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-      const IndexRangePartition ranges_n = StaticPartition(
-          range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
-      return ParallelizeOneRange(
-          ranges_n, cluster, caller,
-          [&](const IndexRange& worker_range, size_t worker) {
-            func(worker_range, worker);
-          });
-    }
-
-    // Assign each cluster a sub-range of `range_n` (typically hundreds).
-    const IndexRangePartition ranges_n =
-        StaticPartition(range_n, num_clusters, n_multiple);
-    ParallelizeOneRange(
-        ranges_n, all_clusters, caller,
-        [&](const IndexRange& n_range, const size_t cluster_idx) {
-          hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-          const size_t cluster_base = ctx.Worker(cluster_idx);
-          // Parallel-for over sub-ranges of `cluster_range` within the cluster.
-          const IndexRangePartition worker_ranges = StaticPartition(
-              n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
-          ParallelizeOneRange(
-              worker_ranges, cluster, caller,
+    // Assign clusters (if any) a sub-range of `range_n` (typically hundreds).
+    ParallelPartitionAcrossClusters(
+        range_n, n_multiple, /*inner_tasks=*/1, ctx, caller,
+        [&](const IndexRange& cluster_range, size_t cluster_idx) {
+          ParallelPartitionWithinCluster(
+              cluster_range, n_multiple, inner_tasks, ctx, cluster_idx, caller,
               [&](const IndexRange& worker_range, size_t worker) {
-                func(worker_range, cluster_base + worker);
+                func(worker_range, worker);
              });
        });
  }
@@ -205,57 +178,44 @@ struct MMParallelHierarchical {
   void ForRangesMC_NC(ThreadingContext& ctx,
                       const IndexRangePartition& ranges_mc,
                       const IndexRangePartition& ranges_nc,
-                      HWY_MAYBE_UNUSED size_t caller_cluster_idx,
-                      const Func& func) const {
+                      size_t caller_cluster_idx, const Func& func) const {
     HWY_DASSERT(caller_cluster_idx == 0);
+    (void)caller_cluster_idx;
     const hwy::pool::Caller caller =
         ctx.pool_callers.Get(Callers::kMMHierForMCNC);
-    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
-    // `all_clusters` is a pool with one worker per cluster in a package.
-    const size_t num_clusters = all_clusters.NumWorkers();
-    // Single (big) cluster: collapse two range indices into one parallel-for
-    // to reduce the number of fork-joins.
-    if (num_clusters == 1) {
-      const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-      // Low-batch: avoid Divide/Remainder.
-      if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
-        return ParallelizeOneRange(
-            ranges_nc, cluster, caller,
-            [&](const IndexRange& range_nc, size_t worker) {
-              func(ranges_mc.Range(0), range_nc, worker);
-            });
-      } else {
-        return ParallelizeTwoRanges(
-            ranges_mc, ranges_nc, cluster, caller,
-            [&](const IndexRange& range_mc, const IndexRange& range_nc,
-                size_t worker) { func(range_mc, range_nc, worker); });
-      }
-    }
+    // Collapse two range indices into a 1D range for better load-balancing,
+    // because `ranges_mc` may just have one task.
+    const hwy::Divisor div_m(static_cast<uint32_t>(ranges_mc.NumTasks()));
+    const auto get_mc = [&](uint64_t task) {
+      return ranges_mc.Range(div_m.Remainder(static_cast<uint32_t>(task)));
+    };
+    const auto get_nc = [&](uint64_t task) {
+      return ranges_nc.Range(div_m.Divide(static_cast<uint32_t>(task)));
+    };
+    const IndexRange all_range(0, ranges_mc.NumTasks() * ranges_nc.NumTasks());
-    // Multiple clusters: N across clusters (both are usually the larger), and
-    // M within each cluster. We assume auto-tuning finds small MC/NC tasks.
-    ParallelizeOneRange(
-        ranges_nc, all_clusters, caller,
-        [&](const IndexRange range_nc, size_t cluster_idx) {
-          const size_t cluster_base = ctx.Worker(cluster_idx);
-          hwy::ThreadPool& cluster = ctx.pools.Cluster(cluster_idx);
-          ParallelizeOneRange(ranges_mc, cluster, caller,
-                              [&](const IndexRange& range_mc, size_t worker) {
-                                func(range_mc, range_nc, cluster_base + worker);
-                              });
+    ParallelPartitionAcrossClusters(
+        all_range, /*task_multiple=*/1, /*inner_tasks=*/1, ctx, caller,
+        [&](const IndexRange& cluster_range, size_t cluster_idx) {
+          ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx,
+                                   caller, [&](uint64_t i, size_t worker) {
+                                     const size_t task =
+                                         cluster_range.begin() + i;
+                                     func(get_mc(task), get_nc(task), worker);
+                                   });
        });
  }
 
-  // Calls `func(row_a, worker)` in parallel.
+  // No multiple/inner_tasks, so this is just HierarchicalParallelFor.
   template <class Func>
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t caller_cluster_idx, const Func& func) const {
-    HierarchicalParallelFor(range_mc.Num(), ctx, Callers::kMMHierForMC,
-                            [&](size_t task, size_t worker) {
-                              func(range_mc.begin() + task, worker);
-                            });
+    HWY_DASSERT(caller_cluster_idx == 0);
+    (void)caller_cluster_idx;
+    HierarchicalParallelFor(
+        range_mc.Num(), ctx, Callers::kMMHierForMC,
+        [&](size_t i, size_t worker) { func(range_mc.begin() + i, worker); });
  }
};
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 2aaf301..4787122 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -195,9 +195,10 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
   const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
   const IndexRangePartition get_col_c =
       StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
-  ParallelizeOneRange(
-      get_col_c, all_clusters, env.ctx.pool_callers.Get(Callers::kTest),
-      [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
+  ParallelForAcrossClusters(
+      get_col_c.NumTasks(), env.ctx, env.ctx.pool_callers.Get(Callers::kTest),
+      [&](size_t range_idx, size_t cluster_idx) HWY_ATTR {
+        const IndexRange cols_c = get_col_c.Range(range_idx);
         for (size_t r : all_rows_c) {
           TC* HWY_RESTRICT C_row = C.Row(r);
           for (size_t c : cols_c) {
diff --git a/util/threading.h b/util/threading.h
index dcdcf24..b18ad24 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -262,43 +262,6 @@ static inline IndexRangePartition StaticPartition(const IndexRange& range,
   return IndexRangePartition(range, size);
 }
 
-// Parallel-for over a single range. This takes care of translating the task
-// index to a range.
-template <class Func>
-void ParallelizeOneRange(const IndexRangePartition& get1, hwy::ThreadPool& pool,
-                         hwy::pool::Caller caller, const Func& func) {
-  const size_t num_tasks = get1.NumTasks();
-  pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
-    const IndexRange range1 = get1.Range(task);
-    func(range1, thread);
-  });
-}
-
-// Parallel-for over the Cartesian product of the two sets of ranges. This
-// combines their indices into a single 'task' so they can be executed by one
-// `pool.Run`, which increases the amount of work available to workers and
-// reduces fork-join overhead vs. nested parallel-for loops. Calls `func` with
-// the two ranges and the thread index within `pool`.
-template <class Func>
-void ParallelizeTwoRanges(const IndexRangePartition& get1,
-                          const IndexRangePartition& get2,
-                          hwy::ThreadPool& pool, hwy::pool::Caller caller,
-                          const Func& func) {
-  const hwy::Divisor div1(static_cast<uint32_t>(get1.NumTasks()));
-
-  const size_t num_tasks = get1.NumTasks() * get2.NumTasks();
-  pool.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
-    HWY_DASSERT(task < (uint64_t{1} << 32));
-    const size_t idx2 = div1.Divide(static_cast<uint32_t>(task));
-    const size_t idx1 = div1.Remainder(static_cast<uint32_t>(task));
-    HWY_DASSERT(idx1 < get1.NumTasks());
-    HWY_DASSERT(idx2 < get2.NumTasks());
-    const IndexRange range1 = get1.Range(idx1);
-    const IndexRange range2 = get2.Range(idx2);
-    func(range1, range2, thread);
-  });
-}
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
diff --git a/util/threading_context.h b/util/threading_context.h
index 251888f..b3e2a52 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -154,42 +154,96 @@ enum class ParallelismStrategy : uint8_t {
   kHierarchical,
 };
 
-// Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
-// over clusters of ONE package, then within each cluster.
+// Helper functions used to implement `ParallelFor`, also reused in multiple
+// places. User code should call `ParallelFor` instead, which accepts the more
+// convenient `Callers` enum.
+//
+// These call `func(task, worker)` for each task in `[0, num_tasks)`.
+
+// NOTE: the worker argument is actually the `cluster_idx`, so that `Func` can
+// pass that to `ParallelForWithinCluster`.
+template <class Func>
+void ParallelForAcrossClusters(size_t num_tasks, ThreadingContext& ctx,
+                               hwy::pool::Caller caller, const Func& func) {
+  ctx.pools.AllClusters().Run(
+      0, num_tasks, caller,
+      [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
+}
+
+template <class Func>
+void ParallelForWithinCluster(size_t num_tasks, ThreadingContext& ctx,
+                              size_t cluster_idx, hwy::pool::Caller caller,
+                              const Func& func) {
+  const size_t cluster_base = ctx.Worker(cluster_idx);
+  ctx.pools.Cluster(cluster_idx)
+      .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
+        func(task, cluster_base + worker);
+      });
+}
+
+// Calls `func(range, cluster_idx)`, for passing to `*WithinCluster`.
+template <class Func>
+void ParallelPartitionAcrossClusters(const IndexRange range,
+                                     size_t task_multiple, size_t inner_tasks,
+                                     ThreadingContext& ctx,
+                                     hwy::pool::Caller caller,
+                                     const Func& func) {
+  HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+  const IndexRangePartition ranges = StaticPartition(
+      range, ctx.pools.NumClusters() * inner_tasks, task_multiple);
+  ParallelForAcrossClusters(ranges.NumTasks(), ctx, caller,
+                            [&](uint64_t task, size_t cluster_idx) {
+                              func(ranges.Range(task), cluster_idx);
+                            });
+}
+
+// Calls `func(range, worker)`.
+template <class Func>
+void ParallelPartitionWithinCluster(const IndexRange range,
+                                    size_t task_multiple, size_t inner_tasks,
+                                    ThreadingContext& ctx, size_t cluster_idx,
+                                    hwy::pool::Caller caller,
+                                    const Func& func) {
+  HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+  const size_t num_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
+  const IndexRangePartition ranges =
+      StaticPartition(range, num_workers * inner_tasks, task_multiple);
+  ParallelForWithinCluster(
+      ranges.NumTasks(), ctx, cluster_idx, caller,
+      [&](uint64_t task, size_t worker) { func(ranges.Range(task), worker); });
+}
+
+// Parallelizes across clusters, then within each cluster.
 template <class Func>
 void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
                              Callers callers, const Func& func) {
   const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
-  // If few tasks, run on a single cluster. Also avoids a bit of overhead if
-  // there is only one cluster.
-  hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
-  const size_t num_clusters = all_clusters.NumWorkers();
-  hwy::ThreadPool& cluster = ctx.pools.Cluster(0);
-  if (num_clusters == 1 || num_tasks <= cluster.NumWorkers()) {
-    return cluster.Run(0, num_tasks, caller, [&](uint64_t task, size_t thread) {
-      func(task, thread);
-    });
+
+  // If at most one task per cluster worker, run on a single cluster to avoid
+  // the expensive cross-cluster barrier.
+  {
+    const size_t cluster_idx = 0;
+    const size_t cluster_workers = ctx.pools.Cluster(cluster_idx).NumWorkers();
+    if (HWY_UNLIKELY(num_tasks <= cluster_workers)) {
+      return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
+                                      func);
+    }
  }
-  // Assign each cluster a sub-range.
-  const IndexRangePartition ranges =
-      StaticPartition(IndexRange(0, num_tasks), num_clusters, 1);
-  ParallelizeOneRange(ranges, all_clusters, caller,
-                      [&](const IndexRange& range, const size_t cluster_idx) {
-                        hwy::ThreadPool& cluster =
-                            ctx.pools.Cluster(cluster_idx);
-                        const size_t cluster_base =
-                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
-                        cluster.Run(range.begin(), range.end(), caller,
-                                    [&](uint64_t task, size_t thread) {
-                                      func(task, cluster_base + thread);
-                                    });
-                      });
+  ParallelPartitionAcrossClusters(
+      IndexRange(0, num_tasks), /*task_multiple=*/1, /*inner_tasks=*/1, ctx,
+      caller, [&](const IndexRange& cluster_range, size_t cluster_idx) {
+        ParallelForWithinCluster(cluster_range.Num(), ctx, cluster_idx, caller,
+                                 [&](uint64_t i, size_t worker) {
+                                   func(cluster_range.begin() + i, worker);
+                                 });
+      });
 }
 
 // Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
-// number/type of workers determined by `parallelism`. `cluster_idx` is for
-// `parallelism == kWithinCluster`, and should be 0 if unknown.
+// number/type of workers determined by `parallelism`. NOTE: worker is actually
+// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
+// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
 template <class Func>
 void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
                  ThreadingContext& ctx, size_t cluster_idx, Callers callers,
@@ -212,37 +266,25 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
    }
 
     case ParallelismStrategy::kAcrossClusters:
-      return ctx.pools.AllClusters().Run(
-          0, num_tasks, caller,
+      return ParallelForAcrossClusters(
+          num_tasks, ctx, caller,
          [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
 
-    case ParallelismStrategy::kWithinCluster: {
-      // Ensure the worker argument is unique across clusters, because it is
-      // used for TLS indexing for example in profiler.h.
-      const size_t base = ctx.Worker(cluster_idx);
-      return ctx.pools.Cluster(cluster_idx)
-          .Run(0, num_tasks, caller, [&](uint64_t task, size_t worker) {
-            func(task, base + worker);
-          });
-    }
+    case ParallelismStrategy::kWithinCluster:
+      return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
+                                      func);
 
-    case ParallelismStrategy::kFlat: {
-      // Check for single cluster; if not, we must compute `cluster_base` for
-      // consistent and non-overlapping worker indices.
-      hwy::ThreadPool& all_clusters = ctx.pools.AllClusters();
-      const size_t num_clusters = all_clusters.NumWorkers();
-      if (num_clusters == 1) {
-        return ctx.pools.Cluster(cluster_idx)
-            .Run(0, num_tasks, caller,
-                 [&](uint64_t task, size_t worker) { func(task, worker); });
+    case ParallelismStrategy::kFlat:
+      // Choose a single pool: the only cluster, or across all clusters
+      // (slower synchronization, but more memory bandwidth)
+      if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
+        return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
+                                        func);
      }
-
-      return all_clusters.Run(0, num_tasks, caller,
-                              [&](uint64_t task, size_t cluster_idx) {
-                                const size_t worker = ctx.Worker(cluster_idx);
-                                func(task, worker);
-                              });
-    }
+      return ParallelForAcrossClusters(num_tasks, ctx, caller,
+                                       [&](uint64_t task, size_t cluster_idx) {
+                                         func(task, ctx.Worker(cluster_idx));
+                                       });
 
     case ParallelismStrategy::kHierarchical:
       return HierarchicalParallelFor(num_tasks, ctx, callers, func);
diff --git a/util/threading_test.cc b/util/threading_test.cc
index 14ea1a0..d6b98a3 100644
--- a/util/threading_test.cc
+++ b/util/threading_test.cc
@@ -202,57 +202,6 @@ TEST(ThreadingTest, TestStaticPartition) {
  }
}
 
-TEST(ThreadingTest, TestParallelizeOneRange) {
-  const IndexRange range(0, 10);
-  const IndexRangePartition partition = StaticPartition(range, 2, 4);
-  hwy::ThreadPool null_pool(0);
-  size_t calls = 0;
-  ParallelizeOneRange(partition, null_pool, kCaller,
-                      [&](const IndexRange& range, size_t) {
-                        if (++calls == 1) {
-                          HWY_ASSERT(range.begin() == 0 && range.end() == 8);
-                        } else {
-                          HWY_ASSERT(range.begin() == 8 && range.end() == 10);
-                        }
-                      });
-  HWY_ASSERT(calls == 2);
-}
-
-TEST(ThreadingTest, TestParallelizeTwoRanges) {
-  const IndexRangePartition partition1 =
-      StaticPartition(IndexRange(0, 10), 2, 4);
-  const IndexRangePartition partition2 =
-      MaxSizePartition(IndexRange(128, 256), 32, 32);
-  HWY_ASSERT(partition2.NumTasks() == 4);
-  hwy::ThreadPool null_pool(0);
-  {
-    size_t calls = 0;
-    ParallelizeTwoRanges(
-        partition1, partition2, null_pool, kCaller,
-        [&](const IndexRange& range1, const IndexRange& range2, size_t) {
-          ++calls;
-          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
-          HWY_ASSERT(range2.begin() % 32 == 0);
-          HWY_ASSERT(range2.Num() % 32 == 0);
-        });
-    HWY_ASSERT(calls == 2 * 4);
-  }
-
-  // Also swap order to test Remainder() logic.
-  {
-    size_t calls = 0;
-    ParallelizeTwoRanges(
-        partition2, partition1, null_pool, kCaller,
-        [&](const IndexRange& range2, const IndexRange& range1, size_t) {
-          ++calls;
-          HWY_ASSERT(range1.begin() == 0 || range1.begin() == 8);
-          HWY_ASSERT(range2.begin() % 32 == 0);
-          HWY_ASSERT(range2.Num() % 32 == 0);
-        });
-    HWY_ASSERT(calls == 2 * 4);
-  }
-}
-
 static constexpr size_t kU64PerThread = HWY_ALIGNMENT / sizeof(size_t);
 static uint64_t outputs[hwy::kMaxLogicalProcessors * kU64PerThread];