Add ParallelFor wrapper function and one new mode

Move ParallelismType from matmul.h to threading.h
Replace SmallParallelFor with ParallelFor and the new mode

PiperOrigin-RevId: 802038452
Authored by Jan Wassenberg on 2025-09-02 01:39:28 -07:00; committed by Copybara-Service
parent 3737224132
commit 1e3c853e80
7 changed files with 110 additions and 60 deletions
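Before the per-file diffs, a minimal sketch of the migration the commit message describes. The helper below (ScaleRowsBatched) and its body are illustrative only and not part of this commit; ParallelFor, ParallelismType, ThreadingContext and ctx.pools are taken from the diffs that follow. A call site that previously used SmallParallelFor(num_rows, ctx.pools, func) now names a mode and passes a cluster index:

#include <stddef.h>
#include <stdint.h>

#include "util/threading.h"          // ParallelFor, ParallelismType (new home per this commit)
#include "util/threading_context.h"  // ThreadingContext

namespace gcpp {

// Hypothetical call site: scale each row of a batch in parallel.
inline void ScaleRowsBatched(float* rows, size_t num_rows, size_t cols,
                             ThreadingContext& ctx, size_t cluster_idx = 0) {
  // Previously: SmallParallelFor(num_rows, ctx.pools, ...);
  ParallelFor(ParallelismType::kAcrossClusters, num_rows, ctx.pools,
              cluster_idx, [&](uint64_t task, size_t worker) {
                (void)worker;  // worker index is unused in this sketch
                float* row = rows + task * cols;
                for (size_t c = 0; c < cols; ++c) row[c] *= 0.5f;
              });
}

}  // namespace gcpp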


@@ -264,6 +264,7 @@ cc_library(
         ":allocator",
         ":basics",
         ":mat",
+        ":threading",
         ":threading_context",
         "//compression:compress",
         "@highway//:bit_set",


@@ -233,9 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                ctx.pools, func);
+    // Full parallelism is helpful, kAcrossClusters is insufficient.
+    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                      ctx.pools, func);
   }
 }


@@ -66,31 +66,39 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
 // No C2 multiplier.
 template <class Mat>
-void ActivationBatched(ActivationType activation, Mat& c1,
-                       ThreadingContext& ctx) {
+void ActivationBatched(
+    ActivationType activation, Mat& c1, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   using T = typename Mat::T;
-  SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    // Cast to correct type so type deduction works.
-    Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
-               c1.Cols(), ctx.profiler, worker);
-  });
+  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+              [&](uint64_t task, size_t worker) {
+                // Cast to correct type so type deduction works.
+                Activation(activation, c1.Row(task),
+                           static_cast<const T*>(nullptr), c1.Cols(),
+                           ctx.profiler, worker);
+              });
 }
 
 template <class Mat1, class Mat2>
-HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1,
-                                    const Mat2* c2, ThreadingContext& ctx) {
+HWY_NOINLINE void ActivationBatched(
+    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
+                             ctx.profiler, worker);
+                });
   } else {  // No multiplier
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task),
-                 static_cast<const typename Mat2::T*>(nullptr), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task),
+                             static_cast<const typename Mat2::T*>(nullptr),
+                             c1.Cols(), ctx.profiler, worker);
+                });
   }
 }


@@ -1155,12 +1155,13 @@ struct MMImpl {
     case ParallelismType::kSequential:
       MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                    range_np)(MMSequentialPolicy(), A, B, C_rows);
-    case ParallelismType::kCluster:
+    case ParallelismType::kWithinCluster:
       MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                    range_np)(MMClusterParallelPolicy(), A, B, C_rows);
       break;
     default:
-      HWY_ABORT("Parallelism type not implemented.");
+      HWY_ABORT("Parallelism type %s not implemented.",
+                static_cast<int>(options.parallelism_type));
       break;
   }
 }
@@ -1189,7 +1190,7 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              MatPtrT<TC>& C, MMOptions options) {
+                              MatPtrT<TC>& C, MMOptions options = MMOptions()) {
   RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
   const Allocator& allocator = env.ctx.allocator;


@@ -27,6 +27,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/mat.h"
+#include "util/threading.h"
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
@@ -56,16 +57,6 @@ static constexpr size_t kMaxMR = 4;
 IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
-enum class ParallelismType : uint8_t {
-  kNone,
-  // No parallelism.
-  kSequential,
-  // Parallelism at cluster level.
-  kCluster,
-  // Parallelism at package level.
-  kNested,
-};
-
 struct MMOptions {
   ParallelismType parallelism_type = ParallelismType::kNested;
   uint8_t cluster_idx = 0;
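Because MatMul (previous file) now defaults its MMOptions argument, existing call sites compile unchanged; callers that want non-default behavior fill in the options explicitly. The wrapper below is a hypothetical illustration, not part of this commit; it uses only the MMOptions fields and the MatMul signature shown in these diffs, and assumes the surrounding declarations (MatPtrT, MatMulEnv, MMPerKey) are in scope via matmul.h:

// Hypothetical convenience wrapper: run MatMul on the cores of a single
// cluster, e.g. when the caller already holds a cluster_idx from an outer
// kAcrossClusters loop.
template <typename TA, typename TB, typename TC>
MMPerKey* MatMulWithinCluster(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                              const float* add, MatMulEnv& env, MatPtrT<TC>& C,
                              size_t cluster_idx) {
  MMOptions options;
  options.parallelism_type = ParallelismType::kWithinCluster;
  options.cluster_idx = static_cast<uint8_t>(cluster_idx);
  return MatMul(A, B, add, env, C, options);
}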


@@ -494,14 +494,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out, ThreadingContext& ctx) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx,
+                    size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
+        cluster_idx, [&](uint64_t token_idx, size_t worker) {
           RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
                   out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
         });
@@ -510,13 +512,14 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
 template <typename XT>
 void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
-                           ThreadingContext& ctx) {
+                           ThreadingContext& ctx, size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
+        [&](uint64_t token_idx, size_t worker) {
           RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx),
                          inout.Cols(), ctx.profiler, worker);
         });
@@ -542,13 +545,14 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
 template <typename XT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
-                                      ThreadingContext& ctx) {
+                                      ThreadingContext& ctx,
+                                      size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
-  SmallParallelFor(out.Rows(), ctx.pools,
-                   [&](uint64_t token_idx, size_t worker) {
-                     AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
-                             ctx.profiler, worker);
-                   });
+  ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t token_idx, size_t worker) {
+                AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
+                        ctx.profiler, worker);
+              });
 }
 
 template <typename XT>
@@ -776,13 +780,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
-    ThreadingContext& ctx) {
+    ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    if (non_eos.Get(task)) {
-      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
-    }
-  });
+  ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t task, size_t worker) {
+                if (non_eos.Get(task)) {
+                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
+                                worker);
+                }
+              });
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t


@@ -326,7 +326,7 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   // Even if there are multiple packages, we only use the first.
   const size_t pkg_idx = 0;
@@ -356,14 +356,57 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   });
 }
 
-// As above, but for lightweight tasks. Uses only one pool.
-template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
-  // Even if there are multiple packages, we only use the first.
-  const size_t pkg_idx = 0;
-
-  pools.Pool(pkg_idx).Run(
-      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
+// Which pool(s) to use for parallelizing:
+enum class ParallelismType : uint8_t {
+  // None: single-threaded loop on the calling thread.
+  kSequential,
+  // One thread per cluster within the first package; or one per core if there
+  // is only one cluster. Use for few or lightweight tasks, or to maximize
+  // memory bandwidth availability.
+  kAcrossClusters,
+  // All cores within the cluster identified by `cluster_idx`. Use if already
+  // within a `kAcrossClusters` parallel-for, or if latency is more important
+  // than memory bandwidth.
+  kWithinCluster,
+  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
+  // utilizes all cores, but has higher fork-join overhead (two barriers); use
+  // if there are many or heavy tasks.
+  kNested,
+};
+
+// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
+// number/type of workers determined by `parallelism`. `cluster_idx` is only
+// used if `parallelism == kWithinCluster`.
+template <class Func>
+void ParallelFor(ParallelismType parallelism, size_t num_tasks,
+                 NestedPools& pools, size_t cluster_idx, const Func& func) {
+  if (cluster_idx != 0) {
+    // If already running across clusters, must not use across-cluster modes.
+    HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
+                parallelism != ParallelismType::kNested);
+  }
+  const size_t pkg_idx = 0;
+  switch (parallelism) {
+    case ParallelismType::kSequential:
+      for (size_t task = 0; task < num_tasks; ++task) {
+        func(task, /*worker=*/0);
+      }
+      return;
+    case ParallelismType::kAcrossClusters:
+      return pools.Pool(pkg_idx).Run(
+          0, num_tasks,
+          [&](uint64_t task, size_t worker) { func(task, worker); });
+    case ParallelismType::kWithinCluster:
+      return pools.Cluster(pkg_idx, cluster_idx)
+          .Run(0, num_tasks,
+               [&](uint64_t task, size_t worker) { func(task, worker); });
+    case ParallelismType::kNested:
+      return NestedParallelFor(num_tasks, pools, func);
+  }
 }
 
 }  // namespace gcpp
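To close, a sketch of how the new modes are meant to compose according to the comments above: an outer kAcrossClusters loop hands one task to each cluster, and each task then fans out over that cluster's cores with kWithinCluster. The function, the data layout, and the assumption that the outer worker index can serve as the cluster index are illustrative, not part of this commit:

#include <stddef.h>
#include <stdint.h>

#include "util/threading.h"  // NestedPools, ParallelFor, ParallelismType

namespace gcpp {

// Hypothetical two-level loop over num_blocks blocks of rows_per_block rows.
inline void IncrementRowsNested(float* data, size_t num_blocks,
                                size_t rows_per_block, size_t cols,
                                NestedPools& pools) {
  ParallelFor(ParallelismType::kAcrossClusters, num_blocks, pools,
              /*cluster_idx=*/0, [&](uint64_t block, size_t outer_worker) {
                // Assumption: this pool has one worker per cluster, so the
                // outer worker index doubles as the cluster index.
                ParallelFor(ParallelismType::kWithinCluster, rows_per_block,
                            pools, /*cluster_idx=*/outer_worker,
                            [&](uint64_t row, size_t worker) {
                              (void)worker;  // unused in this sketch
                              float* p =
                                  data + (block * rows_per_block + row) * cols;
                              for (size_t c = 0; c < cols; ++c) p[c] += 1.0f;
                            });
              });
}

}  // namespace gcpp

In practice, kNested performs this two-level partitioning in a single call; the manual composition above only makes the cluster_idx handoff explicit.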