Minor: ParallelismStrategy->Parallelism

PiperOrigin-RevId: 828936578
This commit is contained in:
Jan Wassenberg 2025-11-06 06:55:37 -08:00 committed by Copybara-Service
parent a344a70c59
commit 091b4567c9
13 changed files with 44 additions and 46 deletions

View File

@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding); MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area(); const float scale = SfpStream::kMax / extents.Area();
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) { Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r); float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < extents.cols; c++) {
@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding); MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area(); const float scale = SfpStream::kMax / extents.Area();
ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) { Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r); float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < extents.cols; c++) {

View File

@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Note that 2D parallelism is not worth the fork/join overhead because the // Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight. // tasks are very lightweight.
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV, /*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR { [&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads; const size_t head = task % kv_heads;

View File

@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
{ {
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
// Better than kFlat. // Better than kFlat.
ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func); /*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
} }
} }
@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
{ {
// kHierarchical is not worth the extra sync overhead because the tasks are // kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight. // very lightweight.
ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding, /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
func); func);
} }
@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const hwy::Divisor div_qbatch(qbatch.Size()); const hwy::Divisor div_qbatch(qbatch.Size());
// Compress q to q_bf. // Compress q to q_bf.
ParallelFor( ParallelFor(
ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx, Parallelism::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention, /*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) { [&](size_t row, size_t worker) {
CompressPerThread tls; CompressPerThread tls;

View File

@ -70,7 +70,7 @@ template <class Mat>
void ActivationBatched( void ActivationBatched(
ActivationType activation, Mat& c1, ThreadingContext& ctx, ActivationType activation, Mat& c1, ThreadingContext& ctx,
size_t cluster_idx = 0, size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { Parallelism parallelism = Parallelism::kFlat) {
using T = typename Mat::T; using T = typename Mat::T;
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
@ -115,7 +115,7 @@ template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched( HWY_NOINLINE void ActivationBatched(
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
size_t cluster_idx = 0, size_t cluster_idx = 0,
ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { Parallelism parallelism = Parallelism::kFlat) {
HWY_DASSERT(c1.SameShape(*c2)); HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) { if (c2 && c2->HasPtr()) {
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,

View File

@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config,
timing_info.NotifyGenerated(non_eos.Count()); timing_info.NotifyGenerated(non_eos.Count());
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, Parallelism::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, Callers::kSampleAndStream, /*cluster_idx=*/0, Callers::kSampleAndStream,
[&](size_t qi, size_t worker) { [&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return; if (!non_eos.Get(qi)) return;

View File

@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners, void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) { ThreadingContext& ctx) {
const size_t cluster_idx = 0; const size_t cluster_idx = 0;
ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx); GetLayer(layer)->Fixup(mat_owners, ctx);
}); });
ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx); VitLayer(layer)->Fixup(mat_owners, ctx);
}); });
@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
// Allocate in parallel because faulting in large tensors is slow. // Allocate in parallel because faulting in large tensors is slow.
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) { Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task]; TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat; MatPtr& mat = *tensor.mat;
@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors, static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) { const BlobReader& reader, ThreadingContext& ctx) {
// Especially TSAN is slow enough to warrant hierarchical parallelism. // Especially TSAN is slow enough to warrant hierarchical parallelism.
const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD const Parallelism parallelism =
? ParallelismStrategy::kHierarchical HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
: ParallelismStrategy::kFlat; ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) { Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task]; const TensorToRead& tensor = tensors[task];
@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches, const std::vector<IOBatch>& batches,
ThreadingContext& ctx) { ThreadingContext& ctx) {
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
/*cluster_idx=*/0, Callers::kReadBatches, /*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) { [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches); GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);

View File

@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
ThreadingContext& ctx, size_t cluster_idx) { ThreadingContext& ctx, size_t cluster_idx) {
HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size());
ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size()); HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes, reader.file().Read(ranges[i].offset, ranges[i].bytes,
@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters()); ctx.pools.NumClusters());
ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest, ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) { [&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx); task ? blobs1 : blobs2, ctx, cluster_idx);
@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{}; std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{}; std::atomic<size_t> blobs_diff{};
ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
Callers::kTest, [&](size_t i, size_t /*thread*/) { Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches = const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]); BlobDifferences(blobs1[i], blobs2[i], keys[i]);

View File

@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes); static_cast<const uint8_t*>(data), writes);
const ParallelismStrategy strategy = file_->IsAppendOnly() const Parallelism parallelism =
? ParallelismStrategy::kNone file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
: ParallelismStrategy::kFlat;
ParallelFor( ParallelFor(
strategy, writes.size(), ctx_, parallelism, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter, /*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) { [this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range; const BlobRange& range = writes[i].range;

View File

@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0, Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](uint64_t i, size_t /*thread*/) { Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
std::to_string(i).c_str()); std::to_string(i).c_str());

View File

@ -1126,7 +1126,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats; std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor( ParallelFor(
ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest, Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) { [&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread); float* HWY_RESTRICT pb = b.Row(thread);

View File

@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
// Policy classes for parallelism, implementing some of `ParallelismStrategy`. // Policy classes for parallelism, implementing some of `Parallelism`.
struct MMParallelNone { struct MMParallelNone {
template <class Func> template <class Func>
@ -220,14 +220,14 @@ struct MMParallelHierarchical {
}; };
template <class Func, typename... Args> template <class Func, typename... Args>
void DispatchParallelism(ParallelismStrategy parallelism, const Func& func, void DispatchParallelism(Parallelism parallelism, const Func& func,
Args&&... args) { Args&&... args) {
switch (parallelism) { switch (parallelism) {
case ParallelismStrategy::kNone: case Parallelism::kNone:
return func(MMParallelNone(), std::forward<Args>(args)...); return func(MMParallelNone(), std::forward<Args>(args)...);
case ParallelismStrategy::kWithinCluster: case Parallelism::kWithinCluster:
return func(MMParallelWithinCluster(), std::forward<Args>(args)...); return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
case ParallelismStrategy::kHierarchical: case Parallelism::kHierarchical:
return func(MMParallelHierarchical(), std::forward<Args>(args)...); return func(MMParallelHierarchical(), std::forward<Args>(args)...);
default: default:
HWY_UNREACHABLE; HWY_UNREACHABLE;
@ -716,7 +716,7 @@ class MMOptions {
const void* opaque = nullptr; const void* opaque = nullptr;
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; Parallelism parallelism = Parallelism::kHierarchical;
}; };
// Arguments to MatMul() that are independent of the A/B/C types. Reduces // Arguments to MatMul() that are independent of the A/B/C types. Reduces

View File

@ -500,7 +500,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
activations.DebugCheckSameShape(out); activations.DebugCheckSameShape(out);
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
cluster_idx, Callers::kOpsRMSNormBatched, cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
@ -517,7 +517,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
HWY_DASSERT(weights.Cols() == inout.Cols()); HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) { CallUpcasted(&weights, [&](const auto* weights_t) {
ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched, Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) { [&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
@ -550,7 +550,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
size_t cluster_idx = 0) { size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x)); HWY_DASSERT(out.SameShape(x));
ParallelFor( ParallelFor(
ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) { Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker); AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
}); });
@ -1290,7 +1290,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos, const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
ThreadingContext& ctx, size_t cluster_idx = 0) { ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return; if (cap == 0.0f) return;
ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched, Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) { [&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) { if (non_eos.Get(task)) {

View File

@ -100,7 +100,7 @@ struct ThreadingContext {
// Returns a worker index compatible with those from `ParallelFor`, assuming // Returns a worker index compatible with those from `ParallelFor`, assuming
// the current thread is running on one thread per cluster, which happens // the current thread is running on one thread per cluster, which happens
// when `ParallelismStrategy` is `kAcrossClusters`. // when `Parallelism` is `kAcrossClusters`.
size_t Worker(size_t cluster_idx) const { size_t Worker(size_t cluster_idx) const {
return cluster_idx * pools.MaxWorkersPerCluster(); return cluster_idx * pools.MaxWorkersPerCluster();
} }
@ -130,7 +130,7 @@ struct ThreadingContext {
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum)) PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores. // Describes the strategy for distributing parallel work across cores.
enum class ParallelismStrategy : uint8_t { enum class Parallelism : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker` // Execute using a single-threaded loop on the calling thread. The `worker`
// index passed to the user's `Func` is unique across clusters. // index passed to the user's `Func` is unique across clusters.
kNone, kNone,
@ -245,19 +245,19 @@ void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for // `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown. // `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
template <class Func> template <class Func>
void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, void ParallelFor(Parallelism parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, Callers callers, ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) { const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) { if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes. // If already running across clusters, only use within-cluster modes.
HWY_DASSERT(parallelism == ParallelismStrategy::kNone || HWY_DASSERT(parallelism == Parallelism::kNone ||
parallelism == ParallelismStrategy::kWithinCluster); parallelism == Parallelism::kWithinCluster);
} }
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) { switch (parallelism) {
case ParallelismStrategy::kNone: { case Parallelism::kNone: {
const size_t worker = ctx.Worker(cluster_idx); const size_t worker = ctx.Worker(cluster_idx);
for (size_t task = 0; task < num_tasks; ++task) { for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker); func(task, worker);
@ -265,16 +265,16 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
return; return;
} }
case ParallelismStrategy::kAcrossClusters: case Parallelism::kAcrossClusters:
return ParallelForAcrossClusters( return ParallelForAcrossClusters(
num_tasks, ctx, caller, num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
case ParallelismStrategy::kWithinCluster: case Parallelism::kWithinCluster:
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
func); func);
case ParallelismStrategy::kFlat: case Parallelism::kFlat:
// Choose a single pool: the only cluster, or across all clusters // Choose a single pool: the only cluster, or across all clusters
// (slower synchronization, but more memory bandwidth) // (slower synchronization, but more memory bandwidth)
if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) { if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
@ -286,7 +286,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
func(task, ctx.Worker(cluster_idx)); func(task, ctx.Worker(cluster_idx));
}); });
case ParallelismStrategy::kHierarchical: case Parallelism::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx, callers, func); return HierarchicalParallelFor(num_tasks, ctx, callers, func);
} }
} }