From 091b4567c9fe3574204698dc031299215fe27259 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 6 Nov 2025 06:55:37 -0800 Subject: [PATCH] Minor: ParallelismStrategy->Parallelism PiperOrigin-RevId: 828936578 --- compression/test_util-inl.h | 4 ++-- gemma/attention.cc | 2 +- gemma/flash_attention.cc | 6 +++--- gemma/gemma-inl.h | 4 ++-- gemma/gemma.cc | 2 +- gemma/weights.cc | 15 +++++++-------- io/blob_compare.cc | 6 +++--- io/blob_store.cc | 7 +++---- io/blob_store_test.cc | 2 +- ops/dot_test.cc | 2 +- ops/matmul.h | 12 ++++++------ ops/ops-inl.h | 8 ++++---- util/threading_context.h | 20 ++++++++++---------- 13 files changed, 44 insertions(+), 46 deletions(-) diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 81442f7..f2c8b8c 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -105,7 +105,7 @@ MatStorageT GenerateMat(const Extents2D& extents, MatPadding padding, MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT compressed("mat", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); - ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, + ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0, Callers::kTest, [&](size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); for (size_t c = 0; c < extents.cols; c++) { @@ -134,7 +134,7 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, MatStorageT raw("raw", extents, ctx.allocator, MatPadding::kPacked); MatStorageT compressed("trans", extents, ctx.allocator, padding); const float scale = SfpStream::kMax / extents.Area(); - ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0, + ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0, Callers::kTest, [&](size_t r, size_t thread) { float* HWY_RESTRICT row = raw.Row(r); for (size_t c = 0; c < extents.cols; c++) { diff --git a/gemma/attention.cc b/gemma/attention.cc index ad464e7..854a489 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Note that 2D parallelism is not worth the fork/join overhead because the // tasks are very lightweight. ParallelFor( - ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx, + Parallelism::kFlat, kv_heads * num_interleaved, env.ctx, /*cluster_idx=*/0, Callers::kAttComputeQKV, [&](size_t task, size_t worker) HWY_ATTR { const size_t head = task % kv_heads; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index b9b0c8a..803067b 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT& q, MatPtrT& q_t, { const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF); // Better than kFlat. - ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx, + ParallelFor(Parallelism::kHierarchical, num_tasks, ctx, /*cluster_idx=*/0, Callers::kFlashTransposeQ, func); } } @@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch, { // kHierarchical is not worth the extra sync overhead because the tasks are // very lightweight. 
- ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx, + ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx, /*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding, func); } @@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism, const hwy::Divisor div_qbatch(qbatch.Size()); // Compress q to q_bf. ParallelFor( - ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx, + Parallelism::kWithinCluster, activations.q.Rows(), ctx, /*cluster_idx=*/0, Callers::kFlashAttention, [&](size_t row, size_t worker) { CompressPerThread tls; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index dc7efea..93f8928 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -70,7 +70,7 @@ template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { + Parallelism parallelism = Parallelism::kFlat) { using T = typename Mat::T; ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, Callers::kActivationBatched, [&](uint64_t task, size_t worker) { @@ -115,7 +115,7 @@ template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, size_t cluster_idx = 0, - ParallelismStrategy parallelism = ParallelismStrategy::kFlat) { + Parallelism parallelism = Parallelism::kFlat) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ae0c6c1..2f342bf 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config, timing_info.NotifyGenerated(non_eos.Count()); ParallelFor( - ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, + Parallelism::kFlat, qbatch.Size(), env.ctx, /*cluster_idx=*/0, Callers::kSampleAndStream, [&](size_t qi, size_t worker) { if (!non_eos.Get(qi)) return; diff --git a/gemma/weights.cc b/gemma/weights.cc index e1e01bf..00c12c6 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { void WeightsPtrs::Fixup(std::vector& mat_owners, ThreadingContext& ctx) { const size_t cluster_idx = 0; - ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx, Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { GetLayer(layer)->Fixup(mat_owners, ctx); }); - ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx, Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) { VitLayer(layer)->Fixup(mat_owners, ctx); }); @@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector& tensors, // Allocate in parallel because faulting in large tensors is slow. ParallelFor( - ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, + Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0, Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) { TensorToRead& tensor = tensors[task]; MatPtr& mat = *tensor.mat; @@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat, static void ReadAllToBF16(const std::vector& tensors, const BlobReader& reader, ThreadingContext& ctx) { // Especially TSAN is slow enough to warrant hierarchical parallelism. - const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD - ? 
ParallelismStrategy::kHierarchical - : ParallelismStrategy::kFlat; - ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0, + const Parallelism parallelism = + HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat; + ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0, Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) { GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16); const TensorToRead& tensor = tensors[task]; @@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader, const std::vector& batches, ThreadingContext& ctx) { // >5x speedup from parallel reads when cached. - ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx, + ParallelFor(Parallelism::kHierarchical, batches.size(), ctx, /*cluster_idx=*/0, Callers::kReadBatches, [&](uint64_t task, size_t thread) { GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches); diff --git a/io/blob_compare.cc b/io/blob_compare.cc index 30a2199..9bb860e 100644 --- a/io/blob_compare.cc +++ b/io/blob_compare.cc @@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs, ThreadingContext& ctx, size_t cluster_idx) { HWY_ASSERT(reader.Keys().size() == blobs.size()); HWY_ASSERT(ranges.size() == blobs.size()); - ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx, + ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx, cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) { HWY_ASSERT(ranges[i].bytes == blobs[i].size()); reader.file().Read(ranges[i].offset, ranges[i].bytes, @@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2, const double t0 = hwy::platform::Now(); HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30, ctx.pools.NumClusters()); - ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest, + ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest, [&](const size_t task, size_t cluster_idx) { ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2, task ? blobs1 : blobs2, ctx, cluster_idx); @@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2, const double t0 = hwy::platform::Now(); std::atomic blobs_equal{}; std::atomic blobs_diff{}; - ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0, + ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0, Callers::kTest, [&](size_t i, size_t /*thread*/) { const size_t mismatches = BlobDifferences(blobs1[i], blobs2[i], keys[i]); diff --git a/io/blob_store.cc b/io/blob_store.cc index af9f81d..8346e4b 100644 --- a/io/blob_store.cc +++ b/io/blob_store.cc @@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) { EnqueueChunks(keys_.size() - 1, curr_offset_, bytes, static_cast(data), writes); - const ParallelismStrategy strategy = file_->IsAppendOnly() - ? ParallelismStrategy::kNone - : ParallelismStrategy::kFlat; + const Parallelism parallelism = + file_->IsAppendOnly() ? 
Parallelism::kNone : Parallelism::kFlat; ParallelFor( - strategy, writes.size(), ctx_, + parallelism, writes.size(), ctx_, /*cluster_idx=*/0, Callers::kBlobWriter, [this, &writes](uint64_t i, size_t /*thread*/) { const BlobRange& range = writes[i].range; diff --git a/io/blob_store_test.cc b/io/blob_store_test.cc index bb41c7e..cf96684 100644 --- a/io/blob_store_test.cc +++ b/io/blob_store_test.cc @@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) { HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); ParallelFor( - ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0, + Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0, Callers::kTest, [&](uint64_t i, size_t /*thread*/) { HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str()); diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 827b6b4..5547e86 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1126,7 +1126,7 @@ void TestAllDot() { std::array all_stats; ParallelFor( - ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest, + Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest, [&](size_t rep, size_t thread) { float* HWY_RESTRICT pa = a.Row(thread); float* HWY_RESTRICT pb = b.Row(thread); diff --git a/ops/matmul.h b/ops/matmul.h index fb29bcc..85deb62 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; -// Policy classes for parallelism, implementing some of `ParallelismStrategy`. +// Policy classes for parallelism, implementing some of `Parallelism`. struct MMParallelNone { template @@ -220,14 +220,14 @@ struct MMParallelHierarchical { }; template -void DispatchParallelism(ParallelismStrategy parallelism, const Func& func, +void DispatchParallelism(Parallelism parallelism, const Func& func, Args&&... args) { switch (parallelism) { - case ParallelismStrategy::kNone: + case Parallelism::kNone: return func(MMParallelNone(), std::forward(args)...); - case ParallelismStrategy::kWithinCluster: + case Parallelism::kWithinCluster: return func(MMParallelWithinCluster(), std::forward(args)...); - case ParallelismStrategy::kHierarchical: + case Parallelism::kHierarchical: return func(MMParallelHierarchical(), std::forward(args)...); default: HWY_UNREACHABLE; @@ -716,7 +716,7 @@ class MMOptions { const void* opaque = nullptr; uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. - ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical; + Parallelism parallelism = Parallelism::kHierarchical; }; // Arguments to MatMul() that are independent of the A/B/C types. 
Reduces diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 80183bb..7ad8e20 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -500,7 +500,7 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, activations.DebugCheckSameShape(out); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx, + ParallelFor(Parallelism::kFlat, activations.Rows(), ctx, cluster_idx, Callers::kOpsRMSNormBatched, [&](uint64_t token_idx, size_t worker) { RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), @@ -517,7 +517,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout, HWY_DASSERT(weights.Cols() == inout.Cols()); CallUpcasted(&weights, [&](const auto* weights_t) { - ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx, Callers::kOpsRMSNormInplaceBatched, [&](uint64_t token_idx, size_t worker) { RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, @@ -550,7 +550,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT& x, MatPtrT& out, size_t cluster_idx = 0) { HWY_DASSERT(out.SameShape(x)); ParallelFor( - ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx, + Parallelism::kFlat, out.Rows(), ctx, cluster_idx, Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) { AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker); }); @@ -1290,7 +1290,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched( const float cap, MatPtrT& x, const hwy::BitSet4096<>& non_eos, ThreadingContext& ctx, size_t cluster_idx = 0) { if (cap == 0.0f) return; - ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx, + ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx, Callers::kOpsMaybeLogitsSoftCapBatched, [&](uint64_t task, size_t worker) { if (non_eos.Get(task)) { diff --git a/util/threading_context.h b/util/threading_context.h index b3e2a52..7a5c3f5 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -100,7 +100,7 @@ struct ThreadingContext { // Returns a worker index compatible with those from `ParallelFor`, assuming // the current thread is running on one thread per cluster, which happens - // when `ParallelismStrategy` is `kAcrossClusters`. + // when `Parallelism` is `kAcrossClusters`. size_t Worker(size_t cluster_idx) const { return cluster_idx * pools.MaxWorkersPerCluster(); } @@ -130,7 +130,7 @@ struct ThreadingContext { PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum)) // Describes the strategy for distributing parallel work across cores. -enum class ParallelismStrategy : uint8_t { +enum class Parallelism : uint8_t { // Execute using a single-threaded loop on the calling thread. The `worker` // index passed to the user's `Func` is unique across clusters. kNone, @@ -245,19 +245,19 @@ void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx, // `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for // `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown. template -void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, +void ParallelFor(Parallelism parallelism, size_t num_tasks, ThreadingContext& ctx, size_t cluster_idx, Callers callers, const Func& func) { HWY_DASSERT(cluster_idx < ctx.topology.NumClusters()); if (cluster_idx != 0) { // If already running across clusters, only use within-cluster modes. 
- HWY_DASSERT(parallelism == ParallelismStrategy::kNone || - parallelism == ParallelismStrategy::kWithinCluster); + HWY_DASSERT(parallelism == Parallelism::kNone || + parallelism == Parallelism::kWithinCluster); } const hwy::pool::Caller caller = ctx.pool_callers.Get(callers); switch (parallelism) { - case ParallelismStrategy::kNone: { + case Parallelism::kNone: { const size_t worker = ctx.Worker(cluster_idx); for (size_t task = 0; task < num_tasks; ++task) { func(task, worker); @@ -265,16 +265,16 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, return; } - case ParallelismStrategy::kAcrossClusters: + case Parallelism::kAcrossClusters: return ParallelForAcrossClusters( num_tasks, ctx, caller, [&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); }); - case ParallelismStrategy::kWithinCluster: + case Parallelism::kWithinCluster: return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller, func); - case ParallelismStrategy::kFlat: + case Parallelism::kFlat: // Choose a single pool: the only cluster, or across all clusters // (slower synchronization, but more memory bandwidth) if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) { @@ -286,7 +286,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks, func(task, ctx.Worker(cluster_idx)); }); - case ParallelismStrategy::kHierarchical: + case Parallelism::kHierarchical: return HierarchicalParallelFor(num_tasks, ctx, callers, func); } }
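
Note: call sites change only in the enum's name; the ParallelFor signature and the
task/worker callback contract are untouched. Below is a minimal usage sketch of the
renamed API, illustrative only and not part of this patch: `ScaleRows` and the float
element type are assumptions, headers/namespace (gcpp) are omitted, `ctx` is an
already-constructed ThreadingContext, and Callers::kTest is borrowed from the tests
above.

  // Illustrative sketch; ScaleRows is a hypothetical helper, not in this patch.
  // Scales each row of a float matrix in parallel.
  static void ScaleRows(MatPtrT<float>& mat, float scale, ThreadingContext& ctx) {
    ParallelFor(Parallelism::kFlat, mat.Rows(), ctx, /*cluster_idx=*/0,
                Callers::kTest, [&](uint64_t r, size_t /*worker*/) {
                  float* HWY_RESTRICT row = mat.Row(r);
                  for (size_t c = 0; c < mat.Cols(); ++c) {
                    row[c] *= scale;
                  }
                });
  }

Parallelism::kFlat picks a single pool: the only cluster when there is just one,
otherwise a pool spanning all clusters (slower synchronization, but more memory
bandwidth), as described in the threading_context.h hunk above.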