Minor: ParallelismStrategy->Parallelism
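
Call sites change only in the enum's name; a minimal before/after sketch of a
typical call (num_tasks and func here are placeholder names, not a specific
call site from this change):

  // before
  ParallelFor(ParallelismStrategy::kFlat, num_tasks, ctx, /*cluster_idx=*/0,
              Callers::kTest, func);
  // after
  ParallelFor(Parallelism::kFlat, num_tasks, ctx, /*cluster_idx=*/0,
              Callers::kTest, func);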

PiperOrigin-RevId: 828936578
Jan Wassenberg 2025-11-06 06:55:37 -08:00 committed by Copybara-Service
parent a344a70c59
commit 091b4567c9
13 changed files with 44 additions and 46 deletions

View File

@@ -105,7 +105,7 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents, MatPadding padding,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
-ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
@@ -134,7 +134,7 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
MatStorageT<float> raw("raw", extents, ctx.allocator, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, ctx.allocator, padding);
const float scale = SfpStream::kMax / extents.Area();
-ParallelFor(ParallelismStrategy::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
+ParallelFor(Parallelism::kFlat, extents.rows, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](size_t r, size_t thread) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {

View File

@@ -278,7 +278,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
// Note that 2D parallelism is not worth the fork/join overhead because the
// tasks are very lightweight.
ParallelFor(
-ParallelismStrategy::kFlat, kv_heads * num_interleaved, env.ctx,
+Parallelism::kFlat, kv_heads * num_interleaved, env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t head = task % kv_heads;

View File

@@ -85,7 +85,7 @@ static void TransposeQ(const MatPtrT<float>& q, MatPtrT<BF16>& q_t,
{
const size_t num_tasks = hwy::DivCeil(q_t.Rows(), kNF);
// Better than kFlat.
-ParallelFor(ParallelismStrategy::kHierarchical, num_tasks, ctx,
+ParallelFor(Parallelism::kHierarchical, num_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashTransposeQ, func);
}
}
@@ -124,7 +124,7 @@ void RMSNormAndPositionalEncoding(const size_t num_tokens, const QBatch& qbatch,
{
// kHierarchical is not worth the extra sync overhead because the tasks are
// very lightweight.
-ParallelFor(ParallelismStrategy::kFlat, num_tokens * qbatch.Size(), ctx,
+ParallelFor(Parallelism::kFlat, num_tokens * qbatch.Size(), ctx,
/*cluster_idx=*/0, Callers::kFlashRMSNormAndPositionalEncoding,
func);
}
@@ -619,7 +619,7 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
const hwy::Divisor div_qbatch(qbatch.Size());
// Compress q to q_bf.
ParallelFor(
-ParallelismStrategy::kWithinCluster, activations.q.Rows(), ctx,
+Parallelism::kWithinCluster, activations.q.Rows(), ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t row, size_t worker) {
CompressPerThread tls;

View File

@@ -70,7 +70,7 @@ template <class Mat>
void ActivationBatched(
ActivationType activation, Mat& c1, ThreadingContext& ctx,
size_t cluster_idx = 0,
-ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+Parallelism parallelism = Parallelism::kFlat) {
using T = typename Mat::T;
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,
Callers::kActivationBatched, [&](uint64_t task, size_t worker) {
@@ -115,7 +115,7 @@ template <class Mat1, class Mat2>
HWY_NOINLINE void ActivationBatched(
ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
size_t cluster_idx = 0,
-ParallelismStrategy parallelism = ParallelismStrategy::kFlat) {
+Parallelism parallelism = Parallelism::kFlat) {
HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) {
ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx,

View File

@@ -426,7 +426,7 @@ static void SampleAndStream(const ModelConfig& config,
timing_info.NotifyGenerated(non_eos.Count());
ParallelFor(
-ParallelismStrategy::kFlat, qbatch.Size(), env.ctx,
+Parallelism::kFlat, qbatch.Size(), env.ctx,
/*cluster_idx=*/0, Callers::kSampleAndStream,
[&](size_t qi, size_t worker) {
if (!non_eos.Get(qi)) return;

View File

@@ -431,12 +431,12 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
const size_t cluster_idx = 0;
-ParallelFor(ParallelismStrategy::kFlat, c_layers.size(), ctx, cluster_idx,
+ParallelFor(Parallelism::kFlat, c_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
GetLayer(layer)->Fixup(mat_owners, ctx);
});
-ParallelFor(ParallelismStrategy::kFlat, vit_layers.size(), ctx, cluster_idx,
+ParallelFor(Parallelism::kFlat, vit_layers.size(), ctx, cluster_idx,
Callers::kFixupWeights, [&](uint64_t layer, size_t /*worker*/) {
VitLayer(layer)->Fixup(mat_owners, ctx);
});
@@ -527,7 +527,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
// Allocate in parallel because faulting in large tensors is slow.
ParallelFor(
-ParallelismStrategy::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
+Parallelism::kFlat, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kAllocateAndBindAll, [&](uint64_t task, size_t /*thread*/) {
TensorToRead& tensor = tensors[task];
MatPtr& mat = *tensor.mat;
@@ -586,10 +586,9 @@ static void DecompressToBF16(MatPtr& mat,
static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
const BlobReader& reader, ThreadingContext& ctx) {
// Especially TSAN is slow enough to warrant hierarchical parallelism.
-const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
-? ParallelismStrategy::kHierarchical
-: ParallelismStrategy::kFlat;
-ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+const Parallelism parallelism =
+HWY_IS_DEBUG_BUILD ? Parallelism::kHierarchical : Parallelism::kFlat;
+ParallelFor(parallelism, tensors.size(), ctx, /*cluster_idx=*/0,
Callers::kReadAllToBF16, [&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadAllToBF16);
const TensorToRead& tensor = tensors[task];
@@ -677,7 +676,7 @@ static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches,
ThreadingContext& ctx) {
// >5x speedup from parallel reads when cached.
-ParallelFor(ParallelismStrategy::kHierarchical, batches.size(), ctx,
+ParallelFor(Parallelism::kHierarchical, batches.size(), ctx,
/*cluster_idx=*/0, Callers::kReadBatches,
[&](uint64_t task, size_t thread) {
GCPP_ZONE(ctx, thread, Zones::kStartupWeightsReadBatches);

View File

@@ -106,7 +106,7 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
ThreadingContext& ctx, size_t cluster_idx) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
-ParallelFor(ParallelismStrategy::kWithinCluster, blobs.size(), ctx,
+ParallelFor(Parallelism::kWithinCluster, blobs.size(), ctx,
cluster_idx, Callers::kTest, [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.file().Read(ranges[i].offset, ranges[i].bytes,
@@ -122,7 +122,7 @@ void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const double t0 = hwy::platform::Now();
HWY_WARN("Reading %zu GiB, %zu clusters: ", total_bytes >> 30,
ctx.pools.NumClusters());
-ParallelFor(ParallelismStrategy::kAcrossClusters, 2, ctx, 0, Callers::kTest,
+ParallelFor(Parallelism::kAcrossClusters, 2, ctx, 0, Callers::kTest,
[&](const size_t task, size_t cluster_idx) {
ReadBlobs(task ? reader1 : reader2, task ? ranges1 : ranges2,
task ? blobs1 : blobs2, ctx, cluster_idx);
@@ -189,7 +189,7 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
const double t0 = hwy::platform::Now();
std::atomic<size_t> blobs_equal{};
std::atomic<size_t> blobs_diff{};
-ParallelFor(ParallelismStrategy::kHierarchical, keys.size(), ctx, 0,
+ParallelFor(Parallelism::kHierarchical, keys.size(), ctx, 0,
Callers::kTest, [&](size_t i, size_t /*thread*/) {
const size_t mismatches =
BlobDifferences(blobs1[i], blobs2[i], keys[i]);

View File

@@ -488,11 +488,10 @@ void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
EnqueueChunks(keys_.size() - 1, curr_offset_, bytes,
static_cast<const uint8_t*>(data), writes);
-const ParallelismStrategy strategy = file_->IsAppendOnly()
-? ParallelismStrategy::kNone
-: ParallelismStrategy::kFlat;
+const Parallelism parallelism =
+file_->IsAppendOnly() ? Parallelism::kNone : Parallelism::kFlat;
ParallelFor(
-strategy, writes.size(), ctx_,
+parallelism, writes.size(), ctx_,
/*cluster_idx=*/0, Callers::kBlobWriter,
[this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range;

View File

@@ -130,7 +130,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
ParallelFor(
-ParallelismStrategy::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
+Parallelism::kFlat, num_blobs, ctx, /*cluster_idx=*/0,
Callers::kTest, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(),
std::to_string(i).c_str());

View File

@@ -1126,7 +1126,7 @@ void TestAllDot() {
std::array<DotStats, kMaxWorkers> all_stats;
ParallelFor(
-ParallelismStrategy::kWithinCluster, kReps, ctx, 0, Callers::kTest,
+Parallelism::kWithinCluster, kReps, ctx, 0, Callers::kTest,
[&](size_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread);

View File

@@ -61,7 +61,7 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
-// Policy classes for parallelism, implementing some of `ParallelismStrategy`.
+// Policy classes for parallelism, implementing some of `Parallelism`.
struct MMParallelNone {
template <class Func>
@@ -220,14 +220,14 @@ struct MMParallelHierarchical {
};
template <class Func, typename... Args>
-void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
+void DispatchParallelism(Parallelism parallelism, const Func& func,
Args&&... args) {
switch (parallelism) {
-case ParallelismStrategy::kNone:
+case Parallelism::kNone:
return func(MMParallelNone(), std::forward<Args>(args)...);
-case ParallelismStrategy::kWithinCluster:
+case Parallelism::kWithinCluster:
return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
-case ParallelismStrategy::kHierarchical:
+case Parallelism::kHierarchical:
return func(MMParallelHierarchical(), std::forward<Args>(args)...);
default:
HWY_UNREACHABLE;
@@ -716,7 +716,7 @@ class MMOptions {
const void* opaque = nullptr;
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
-ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+Parallelism parallelism = Parallelism::kHierarchical;
};
// Arguments to MatMul() that are independent of the A/B/C types. Reduces

View File

@@ -500,7 +500,7 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
activations.DebugCheckSameShape(out);
CallUpcasted(&weights, [&](const auto* weights_t) {
-ParallelFor(ParallelismStrategy::kFlat, activations.Rows(), ctx,
+ParallelFor(Parallelism::kFlat, activations.Rows(), ctx,
cluster_idx, Callers::kOpsRMSNormBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
@@ -517,7 +517,7 @@ void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) {
-ParallelFor(ParallelismStrategy::kFlat, inout.Rows(), ctx, cluster_idx,
+ParallelFor(Parallelism::kFlat, inout.Rows(), ctx, cluster_idx,
Callers::kOpsRMSNormInplaceBatched,
[&](uint64_t token_idx, size_t worker) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
@@ -550,7 +550,7 @@ static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
size_t cluster_idx = 0) {
HWY_DASSERT(out.SameShape(x));
ParallelFor(
-ParallelismStrategy::kFlat, out.Rows(), ctx, cluster_idx,
+Parallelism::kFlat, out.Rows(), ctx, cluster_idx,
Callers::kOpsAddFromBatched, [&](uint64_t token_idx, size_t worker) {
AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(), ctx, worker);
});
@@ -1290,7 +1290,7 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
ThreadingContext& ctx, size_t cluster_idx = 0) {
if (cap == 0.0f) return;
-ParallelFor(ParallelismStrategy::kFlat, x.Rows(), ctx, cluster_idx,
+ParallelFor(Parallelism::kFlat, x.Rows(), ctx, cluster_idx,
Callers::kOpsMaybeLogitsSoftCapBatched,
[&](uint64_t task, size_t worker) {
if (non_eos.Get(task)) {

View File

@@ -100,7 +100,7 @@ struct ThreadingContext {
// Returns a worker index compatible with those from `ParallelFor`, assuming
// the current thread is running on one thread per cluster, which happens
-// when `ParallelismStrategy` is `kAcrossClusters`.
+// when `Parallelism` is `kAcrossClusters`.
size_t Worker(size_t cluster_idx) const {
return cluster_idx * pools.MaxWorkersPerCluster();
}
@@ -130,7 +130,7 @@ struct ThreadingContext {
PROFILER_ZONE3(ctx.profiler, global_idx, ctx.profiler_zones.Get(zone_enum))
// Describes the strategy for distributing parallel work across cores.
-enum class ParallelismStrategy : uint8_t {
+enum class Parallelism : uint8_t {
// Execute using a single-threaded loop on the calling thread. The `worker`
// index passed to the user's `Func` is unique across clusters.
kNone,
@@ -245,19 +245,19 @@ void HierarchicalParallelFor(size_t num_tasks, ThreadingContext& ctx,
// `cluster_idx` for `kAcrossClusters`. The `cluster_idx` argument is for
// `parallelism == {kWithinCluster, kNone}`, and should be 0 if unknown.
template <class Func>
-void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
+void ParallelFor(Parallelism parallelism, size_t num_tasks,
ThreadingContext& ctx, size_t cluster_idx, Callers callers,
const Func& func) {
HWY_DASSERT(cluster_idx < ctx.topology.NumClusters());
if (cluster_idx != 0) {
// If already running across clusters, only use within-cluster modes.
-HWY_DASSERT(parallelism == ParallelismStrategy::kNone ||
-parallelism == ParallelismStrategy::kWithinCluster);
+HWY_DASSERT(parallelism == Parallelism::kNone ||
+parallelism == Parallelism::kWithinCluster);
}
const hwy::pool::Caller caller = ctx.pool_callers.Get(callers);
switch (parallelism) {
-case ParallelismStrategy::kNone: {
+case Parallelism::kNone: {
const size_t worker = ctx.Worker(cluster_idx);
for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker);
@@ -265,16 +265,16 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
return;
}
-case ParallelismStrategy::kAcrossClusters:
+case Parallelism::kAcrossClusters:
return ParallelForAcrossClusters(
num_tasks, ctx, caller,
[&](uint64_t task, size_t cluster_idx) { func(task, cluster_idx); });
-case ParallelismStrategy::kWithinCluster:
+case Parallelism::kWithinCluster:
return ParallelForWithinCluster(num_tasks, ctx, cluster_idx, caller,
func);
-case ParallelismStrategy::kFlat:
+case Parallelism::kFlat:
// Choose a single pool: the only cluster, or across all clusters
// (slower synchronization, but more memory bandwidth)
if (HWY_UNLIKELY(ctx.pools.NumClusters() == 1)) {
@@ -286,7 +286,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
func(task, ctx.Worker(cluster_idx));
});
-case ParallelismStrategy::kHierarchical:
+case Parallelism::kHierarchical:
return HierarchicalParallelFor(num_tasks, ctx, callers, func);
}
}