diff --git a/BUILD.bazel b/BUILD.bazel
index e141e95..ce4cffb 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -264,6 +264,7 @@ cc_library(
         ":allocator",
         ":basics",
         ":mat",
+        ":threading",
         ":threading_context",
         "//compression:compress",
         "@highway//:bit_set",
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 13681d0..bd76329 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -233,9 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                ctx.pools, func);
+    // Full parallelism is helpful, kAcrossClusters is insufficient.
+    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                      ctx.pools, func);
   }
 }
 
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ed0750a..80ec0ee 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -66,31 +66,39 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
 
 // No C2 multiplier.
 template <class Mat>
-void ActivationBatched(ActivationType activation, Mat& c1,
-                       ThreadingContext& ctx) {
+void ActivationBatched(
+    ActivationType activation, Mat& c1, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   using T = typename Mat::T;
-  SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    // Cast to correct type so type deduction works.
-    Activation(activation, c1.Row(task), static_cast<T*>(nullptr),
-               c1.Cols(), ctx.profiler, worker);
-  });
+  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+              [&](uint64_t task, size_t worker) {
+                // Cast to correct type so type deduction works.
+                Activation(activation, c1.Row(task),
+                           static_cast<T*>(nullptr), c1.Cols(),
+                           ctx.profiler, worker);
+              });
 }
 
 template <class Mat1, class Mat2>
-HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1,
-                                    const Mat2* c2, ThreadingContext& ctx) {
+HWY_NOINLINE void ActivationBatched(
+    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
+                             ctx.profiler, worker);
+                });
   } else {  // No multiplier
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task),
-                 static_cast<typename Mat2::T*>(nullptr), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task),
+                             static_cast<typename Mat2::T*>(nullptr),
+                             c1.Cols(), ctx.profiler, worker);
+                });
   }
 }
 
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 1d5dc5d..53dfb05 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -1155,12 +1155,13 @@ struct MMImpl {
      case ParallelismType::kSequential:
        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                     range_np)(MMSequentialPolicy(), A, B, C_rows);
-      case ParallelismType::kCluster:
+      case ParallelismType::kWithinCluster:
        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
        break;
      default:
-        HWY_ABORT("Parallelism type not implemented.");
+        HWY_ABORT("Parallelism type %d not implemented.",
+                  static_cast<int>(options.parallelism_type));
        break;
    }
  }
@@ -1189,7 +1190,7 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              MatPtrT<TC>& C, MMOptions options) {
+                              MatPtrT<TC>& C, MMOptions options = MMOptions()) {
   RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
 
   const Allocator& allocator = env.ctx.allocator;
diff --git a/ops/matmul.h b/ops/matmul.h
index 5c526de..11262bc 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -27,6 +27,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/mat.h"
+#include "util/threading.h"
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
@@ -56,16 +57,6 @@ static constexpr size_t kMaxMR = 4;
 IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
-enum class ParallelismType : uint8_t {
-  kNone,
-  // No parallelism.
-  kSequential,
-  // Parallelism at cluster level.
-  kCluster,
-  // Parallelism at package level.
-  kNested,
-};
-
 struct MMOptions {
   ParallelismType parallelism_type = ParallelismType::kNested;
   uint8_t cluster_idx = 0;
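Usage sketch for the defaulted `MMOptions` argument above (illustrative only,
not part of the patch; `A`, `B`, `C`, `env`, and `cluster_idx` are assumed to
already exist with the types `MatMul` expects):

  // Existing call sites may simply drop the options argument (keeps kNested).
  MatMul(A, B, /*add=*/nullptr, env, C);

  // A caller that is itself parallelized across clusters can confine MatMul
  // to its own cluster.
  MMOptions options;
  options.parallelism_type = ParallelismType::kWithinCluster;
  options.cluster_idx = static_cast<uint8_t>(cluster_idx);
  MatMul(A, B, /*add=*/nullptr, env, C, options);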
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 0438bf7..0173ee8 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -494,14 +494,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out, ThreadingContext& ctx) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx,
+                    size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
+        cluster_idx, [&](uint64_t token_idx, size_t worker) {
           RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
                   out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
         });
@@ -510,13 +512,14 @@ void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights,
 
 template <typename XT>
 void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
-                           ThreadingContext& ctx) {
+                           ThreadingContext& ctx, size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
+        [&](uint64_t token_idx, size_t worker) {
           RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx),
                          inout.Cols(), ctx.profiler, worker);
         });
@@ -542,13 +545,14 @@ void LayerNormBatched(const MatPtrT& x, const MatPtr& weight,
 
 template <typename XT, typename OT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<OT>& out,
-                                      ThreadingContext& ctx) {
+                                      ThreadingContext& ctx,
+                                      size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
-  SmallParallelFor(out.Rows(), ctx.pools,
-                   [&](uint64_t token_idx, size_t worker) {
-                     AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
-                             ctx.profiler, worker);
-                   });
+  ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t token_idx, size_t worker) {
+                AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
+                        ctx.profiler, worker);
+              });
 }
 
 template
@@ -776,13 +780,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
 
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
-    ThreadingContext& ctx) {
+    ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    if (non_eos.Get(task)) {
-      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
-    }
-  });
+  ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t task, size_t worker) {
+                if (non_eos.Get(task)) {
+                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
+                                worker);
+                }
+              });
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
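Because `cluster_idx` is defaulted in each batched op above, existing call
sites compile unchanged; a brief sketch (illustrative, with `activations`,
`weights`, `out`, and `ctx` assumed to be in scope):

  RMSNormBatched(activations, weights, out, ctx);                     // as before
  RMSNormBatched(activations, weights, out, ctx, /*cluster_idx=*/0);  // explicit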
diff --git a/util/threading.h b/util/threading.h
index 0a57ddb..ef4f1c7 100644
--- a/util/threading.h
+++ b/util/threading.h
@@ -326,7 +326,7 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   // Even if there are multiple packages, we only use the first.
   const size_t pkg_idx = 0;
 
@@ -356,14 +356,57 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   });
 }
 
-// As above, but for lightweight tasks. Uses only one pool.
-template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
-  // Even if there are multiple packages, we only use the first.
-  const size_t pkg_idx = 0;
+// Which pool(s) to use for parallelizing:
+enum class ParallelismType : uint8_t {
+  // No parallelism: single-threaded loop on the calling thread.
+  kSequential,
+  // One thread per cluster within the first package; or one per core if there
+  // is only one cluster. Use for few or lightweight tasks, or to maximize
+  // memory bandwidth availability.
+  kAcrossClusters,
+  // All cores within the cluster identified by `cluster_idx`. Use if already
+  // within a `kAcrossClusters` parallel-for, or if latency is more important
+  // than memory bandwidth.
+  kWithinCluster,
+  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
+  // utilizes all cores, but has higher fork-join overhead (two barriers); use
+  // if there are many or heavy tasks.
+  kNested,
+};
 
-  pools.Pool(pkg_idx).Run(
-      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
+// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
+// number/type of workers determined by `parallelism`. `cluster_idx` is only
+// used if `parallelism == kWithinCluster`.
+template <class Func>
+void ParallelFor(ParallelismType parallelism, size_t num_tasks,
+                 NestedPools& pools, size_t cluster_idx, const Func& func) {
+  if (cluster_idx != 0) {
+    // If already running across clusters, must not use across-cluster modes.
+    HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
+                parallelism != ParallelismType::kNested);
+  }
+
+  const size_t pkg_idx = 0;
+  switch (parallelism) {
+    case ParallelismType::kSequential:
+      for (size_t task = 0; task < num_tasks; ++task) {
+        func(task, /*worker=*/0);
+      }
+      return;
+
+    case ParallelismType::kAcrossClusters:
+      return pools.Pool(pkg_idx).Run(
+          0, num_tasks,
+          [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kWithinCluster:
+      return pools.Cluster(pkg_idx, cluster_idx)
+          .Run(0, num_tasks,
+               [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kNested:
+      return NestedParallelFor(num_tasks, pools, func);
+  }
 }
 
 }  // namespace gcpp
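A minimal sketch of the intended call patterns for the new `ParallelFor`
(illustrative only; `ctx`, `num_rows`, `cluster_idx`, and `ComputeRow` are
placeholders, not part of the patch):

  // Few or lightweight tasks: one worker per cluster of the first package.
  ParallelFor(ParallelismType::kAcrossClusters, num_rows, ctx.pools,
              /*cluster_idx=*/0, [&](uint64_t row, size_t worker) {
                ComputeRow(row, worker);
              });

  // Already inside a per-cluster task: use only that cluster's cores. Per the
  // HWY_DASSERT above, a non-zero cluster_idx requires kSequential or
  // kWithinCluster.
  ParallelFor(ParallelismType::kWithinCluster, num_rows, ctx.pools, cluster_idx,
              [&](uint64_t row, size_t worker) { ComputeRow(row, worker); });

  // Many or heavy tasks: two-level fork-join over all cores.
  ParallelFor(ParallelismType::kNested, num_rows, ctx.pools, /*cluster_idx=*/0,
              [&](uint64_t row, size_t worker) { ComputeRow(row, worker); });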