mirror of https://github.com/google/gemma.cpp.git
Add ParallelFor wrapper function and one new mode
Move ParallelismType from matmul.h to threading.h.
Replace SmallParallelFor with ParallelFor and the new mode.

PiperOrigin-RevId: 802038452
parent 3737224132
commit 1e3c853e80
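Roughly, the change replaces the single-pool helper with a mode-aware wrapper. A minimal before/after sketch of a call site, assuming a `ThreadingContext& ctx` and a row count `rows` as in the hunks below:

// Before: always runs on the first package's cluster pool.
SmallParallelFor(rows, ctx.pools, [&](uint64_t task, size_t worker) {
  // per-row work
});

// After: the caller chooses the pool(s) and may pass the cluster it is
// already running on; kAcrossClusters with cluster_idx = 0 reproduces the
// old SmallParallelFor behavior.
ParallelFor(ParallelismType::kAcrossClusters, rows, ctx.pools,
            /*cluster_idx=*/0, [&](uint64_t task, size_t worker) {
              // per-row work
            });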
@@ -264,6 +264,7 @@ cc_library(
         ":allocator",
         ":basics",
         ":mat",
+        ":threading",
         ":threading_context",
         "//compression:compress",
         "@highway//:bit_set",
@@ -233,9 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
-                ctx.pools, func);
+    // Full parallelism is helpful, kAcrossClusters is insufficient.
+    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+                      ctx.pools, func);
   }
 }
@@ -66,31 +66,39 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
 // No C2 multiplier.
 template <class Mat>
-void ActivationBatched(ActivationType activation, Mat& c1,
-                       ThreadingContext& ctx) {
+void ActivationBatched(
+    ActivationType activation, Mat& c1, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   using T = typename Mat::T;
-  SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    // Cast to correct type so type deduction works.
-    Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
-               c1.Cols(), ctx.profiler, worker);
-  });
+  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+              [&](uint64_t task, size_t worker) {
+                // Cast to correct type so type deduction works.
+                Activation(activation, c1.Row(task),
+                           static_cast<const T*>(nullptr), c1.Cols(),
+                           ctx.profiler, worker);
+              });
 }
 
 template <class Mat1, class Mat2>
-HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1,
-                                    const Mat2* c2, ThreadingContext& ctx) {
+HWY_NOINLINE void ActivationBatched(
+    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
+                             ctx.profiler, worker);
+                });
   } else {  // No multiplier
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task),
-                 static_cast<const typename Mat2::T*>(nullptr), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task),
+                             static_cast<const typename Mat2::T*>(nullptr),
+                             c1.Cols(), ctx.profiler, worker);
+                });
   }
 }
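A hedged sketch of a caller of the new ActivationBatched signature; the surrounding per-cluster context, the `kGelu` activation, and the variable names are assumptions, not part of this commit:

// Default arguments keep the previous behavior (rows spread across clusters):
ActivationBatched(ActivationType::kGelu, c1, ctx);

// If the caller is itself inside a kAcrossClusters loop, it can instead pin
// the work to its own cluster:
ActivationBatched(ActivationType::kGelu, c1, ctx, cluster_idx,
                  ParallelismType::kWithinCluster);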
@@ -1155,12 +1155,13 @@ struct MMImpl {
       case ParallelismType::kSequential:
         MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                      range_np)(MMSequentialPolicy(), A, B, C_rows);
-      case ParallelismType::kCluster:
+      case ParallelismType::kWithinCluster:
         MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                      range_np)(MMClusterParallelPolicy(), A, B, C_rows);
         break;
       default:
-        HWY_ABORT("Parallelism type not implemented.");
+        HWY_ABORT("Parallelism type %s not implemented.",
+                  static_cast<int>(options.parallelism_type));
         break;
     }
   }
@@ -1189,7 +1190,7 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              MatPtrT<TC>& C, MMOptions options) {
+                              MatPtrT<TC>& C, MMOptions options = MMOptions()) {
   RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
 
   const Allocator& allocator = env.ctx.allocator;
ops/matmul.h (11 lines changed)
@@ -27,6 +27,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/mat.h"
+#include "util/threading.h"
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
@@ -56,16 +57,6 @@ static constexpr size_t kMaxMR = 4;
 IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
-enum class ParallelismType : uint8_t {
-  kNone,
-  // No parallelism.
-  kSequential,
-  // Parallelism at cluster level.
-  kCluster,
-  // Parallelism at package level.
-  kNested,
-};
-
 struct MMOptions {
   ParallelismType parallelism_type = ParallelismType::kNested;
   uint8_t cluster_idx = 0;
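With `options` now defaulted in `MatMul` and `MMOptions` carrying `parallelism_type` and `cluster_idx`, a caller might look like this sketch (the matrices, `add`, and `env` are assumed to be set up elsewhere; `2` is a hypothetical cluster index):

// Default MMOptions: nested parallelism (kNested).
MatMul(A, B, add, env, C);

// Restrict this MatMul to one cluster, e.g. when running one MatMul per
// cluster from an outer loop.
MMOptions options;
options.parallelism_type = ParallelismType::kWithinCluster;
options.cluster_idx = 2;
MatMul(A, B, add, env, C, options);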
@@ -494,14 +494,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out, ThreadingContext& ctx) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx,
+                    size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
+        cluster_idx, [&](uint64_t token_idx, size_t worker) {
          RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
                  out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
        });
@@ -510,13 +512,14 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
 
 template <typename XT>
 void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
-                           ThreadingContext& ctx) {
+                           ThreadingContext& ctx, size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
+        [&](uint64_t token_idx, size_t worker) {
          RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx),
                         inout.Cols(), ctx.profiler, worker);
        });
@@ -542,13 +545,14 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
 
 template <typename XT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
-                                      ThreadingContext& ctx) {
+                                      ThreadingContext& ctx,
+                                      size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
-  SmallParallelFor(out.Rows(), ctx.pools,
-                   [&](uint64_t token_idx, size_t worker) {
-                     AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
-                             ctx.profiler, worker);
-                   });
+  ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t token_idx, size_t worker) {
+                AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
+                        ctx.profiler, worker);
+              });
 }
 
 template <typename XT>
@@ -776,13 +780,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
 
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
-    ThreadingContext& ctx) {
+    ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    if (non_eos.Get(task)) {
-      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
-    }
-  });
+  ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t task, size_t worker) {
+                if (non_eos.Get(task)) {
+                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
+                                worker);
+                }
+              });
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
@@ -326,7 +326,7 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   // Even if there are multiple packages, we only use the first.
   const size_t pkg_idx = 0;
 
@@ -356,14 +356,57 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
       });
 }
 
-// As above, but for lightweight tasks. Uses only one pool.
-template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
-  // Even if there are multiple packages, we only use the first.
-  const size_t pkg_idx = 0;
+// Which pool(s) to use for parallelizing:
+enum class ParallelismType : uint8_t {
+  // None: single-threaded loop on the calling thread.
+  kSequential,
+  // One thread per cluster within the first package; or one per core if there
+  // is only one cluster. Use for few or lightweight tasks, or to maximize
+  // memory bandwidth availability.
+  kAcrossClusters,
+  // All cores within the cluster identified by `cluster_idx`. Use if already
+  // within a `kAcrossClusters` parallel-for, or if latency is more important
+  // than memory bandwidth.
+  kWithinCluster,
+  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
+  // utilizes all cores, but has higher fork-join overhead (two barriers); use
+  // if there are many or heavy tasks.
+  kNested,
+};
 
-  pools.Pool(pkg_idx).Run(
-      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
+// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
+// number/type of workers determined by `parallelism`. `cluster_idx` is only
+// used if `parallelism == kWithinCluster`.
+template <class Func>
+void ParallelFor(ParallelismType parallelism, size_t num_tasks,
+                 NestedPools& pools, size_t cluster_idx, const Func& func) {
+  if (cluster_idx != 0) {
+    // If already running across clusters, must not use across-cluster modes.
+    HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
+                parallelism != ParallelismType::kNested);
+  }
+
+  const size_t pkg_idx = 0;
+  switch (parallelism) {
+    case ParallelismType::kSequential:
+      for (size_t task = 0; task < num_tasks; ++task) {
+        func(task, /*worker=*/0);
+      }
+      return;
+
+    case ParallelismType::kAcrossClusters:
+      return pools.Pool(pkg_idx).Run(
+          0, num_tasks,
+          [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kWithinCluster:
+      return pools.Cluster(pkg_idx, cluster_idx)
+          .Run(0, num_tasks,
+               [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kNested:
+      return NestedParallelFor(num_tasks, pools, func);
+  }
 }
 
 }  // namespace gcpp
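The comments on `kWithinCluster` suggest a two-level pattern: an outer `kAcrossClusters` loop, then `kWithinCluster` inside it. A sketch under the assumption that the outer worker index can serve as the cluster index (an assumption, not something this commit states); `num_chunks` and `rows_per_chunk` are hypothetical sizes:

ParallelFor(ParallelismType::kAcrossClusters, num_chunks, ctx.pools,
            /*cluster_idx=*/0, [&](uint64_t chunk, size_t cluster_idx) {
              // All cores of this cluster work on one chunk; the DASSERT in
              // ParallelFor permits kWithinCluster with a nonzero cluster_idx.
              ParallelFor(ParallelismType::kWithinCluster, rows_per_chunk,
                          ctx.pools, cluster_idx,
                          [&](uint64_t row, size_t worker) {
                            // per-row work for this chunk
                          });
            });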