mirror of https://github.com/google/gemma.cpp.git
Add ParallelFor wrapper function and one new mode

Move ParallelismType from matmul.h to threading.h.
Replace SmallParallelFor with ParallelFor and the new mode.

PiperOrigin-RevId: 802038452

parent 3737224132
commit 1e3c853e80
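To summarize the API change before the per-file diffs: below is a minimal sketch of the call-site migration this commit performs. It assumes the gemma.cpp headers shown in the diffs; the function `Example` and its body are illustrative only, not part of the commit.

#include <cstddef>
#include <cstdint>

#include "util/threading.h"          // ParallelFor, ParallelismType (after this commit)
#include "util/threading_context.h"  // ThreadingContext

namespace gcpp {

// Hypothetical caller, for illustration only.
void Example(ThreadingContext& ctx, size_t num_rows) {
  // Before: SmallParallelFor(num_rows, ctx.pools, func);
  // After: the pool choice is explicit. kAcrossClusters matches the old
  // SmallParallelFor behavior (the one pool of the first package).
  ParallelFor(ParallelismType::kAcrossClusters, num_rows, ctx.pools,
              /*cluster_idx=*/0, [&](uint64_t task, size_t worker) {
                (void)task;    // ... per-row work would go here ...
                (void)worker;  // index for per-worker scratch/profiling
              });
}

}  // namespace gcpp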
@@ -264,6 +264,7 @@ cc_library(
         ":allocator",
         ":basics",
         ":mat",
+        ":threading",
         ":threading_context",
         "//compression:compress",
         "@highway//:bit_set",
@@ -233,9 +233,9 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
 
   {
     PROFILER_ZONE("Gen.Attention.DotSoftmax.ForkJoin");
-    // Full parallelism is helpful, SmallParallelFor is insufficient.
-    ParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
+    // Full parallelism is helpful, kAcrossClusters is insufficient.
+    NestedParallelFor(num_tokens * div_qbatch.GetDivisor() * layer_config.heads,
                 ctx.pools, func);
   }
 }
 
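Attention keeps full (nested) parallelism because the task count is the product of tokens, query batch, and heads. The lambda `func` lies outside this hunk; purely as a hypothetical sketch, a flattened task index over that product could be decoded like this (the ordering below is an assumption, not taken from the commit):

// Hypothetical decode of the flattened task index; ordering is assumed.
auto func = [&](uint64_t task, size_t worker) {
  const size_t head = task % layer_config.heads;
  const size_t qi = (task / layer_config.heads) % div_qbatch.GetDivisor();
  const size_t token = task / (layer_config.heads * div_qbatch.GetDivisor());
  // ... attention work for (token, qi, head), using per-`worker` scratch ...
};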
@@ -66,31 +66,39 @@ void Activation(ActivationType activation, T1* HWY_RESTRICT c1,
 
 // No C2 multiplier.
 template <class Mat>
-void ActivationBatched(ActivationType activation, Mat& c1,
-                       ThreadingContext& ctx) {
+void ActivationBatched(
+    ActivationType activation, Mat& c1, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   using T = typename Mat::T;
-  SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    // Cast to correct type so type deduction works.
-    Activation(activation, c1.Row(task), static_cast<const T*>(nullptr),
-               c1.Cols(), ctx.profiler, worker);
-  });
+  ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+              [&](uint64_t task, size_t worker) {
+                // Cast to correct type so type deduction works.
+                Activation(activation, c1.Row(task),
+                           static_cast<const T*>(nullptr), c1.Cols(),
+                           ctx.profiler, worker);
+              });
 }
 
 template <class Mat1, class Mat2>
-HWY_NOINLINE void ActivationBatched(ActivationType activation, Mat1& c1,
-                                    const Mat2* c2, ThreadingContext& ctx) {
+HWY_NOINLINE void ActivationBatched(
+    ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx,
+    size_t cluster_idx = 0,
+    ParallelismType parallelism = ParallelismType::kAcrossClusters) {
   HWY_DASSERT(c1.SameShape(*c2));
   if (c2 && c2->HasPtr()) {
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(),
+                             ctx.profiler, worker);
+                });
   } else {  // No multiplier
-    SmallParallelFor(c1.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-      Activation(activation, c1.Row(task),
-                 static_cast<const typename Mat2::T*>(nullptr), c1.Cols(),
-                 ctx.profiler, worker);
-    });
+    ParallelFor(parallelism, c1.Rows(), ctx.pools, cluster_idx,
+                [&](uint64_t task, size_t worker) {
+                  Activation(activation, c1.Row(task),
+                             static_cast<const typename Mat2::T*>(nullptr),
+                             c1.Cols(), ctx.profiler, worker);
+                });
   }
 }
 
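Both new parameters default to the previous behavior, so existing call sites compile unchanged. A short usage sketch under that assumption, given an `ActivationType activation` and the types from the hunk (`cluster_idx` below is a hypothetical caller value):

// Unchanged call site: defaults select kAcrossClusters, as before.
ActivationBatched(activation, c1, ctx);

// Hypothetical caller already running per-cluster: parallelize only within
// its own cluster, which also satisfies ParallelFor's cluster_idx assertion.
ActivationBatched(activation, c1, ctx, cluster_idx,
                  ParallelismType::kWithinCluster);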
@@ -1155,12 +1155,13 @@ struct MMImpl {
     case ParallelismType::kSequential:
       MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                    range_np)(MMSequentialPolicy(), A, B, C_rows);
-    case ParallelismType::kCluster:
+    case ParallelismType::kWithinCluster:
       MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
                    range_np)(MMClusterParallelPolicy(), A, B, C_rows);
       break;
     default:
-      HWY_ABORT("Parallelism type not implemented.");
+      HWY_ABORT("Parallelism type %s not implemented.",
+                static_cast<int>(options.parallelism_type));
       break;
   }
 }
@@ -1189,7 +1190,7 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              MatPtrT<TC>& C, MMOptions options) {
+                              MatPtrT<TC>& C, MMOptions options = MMOptions()) {
   RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[0]);
 
   const Allocator& allocator = env.ctx.allocator;
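With the defaulted `options` argument, call sites no longer need to construct an MMOptions. A sketch using only the fields shown in this commit (the matrices and the cluster index value are placeholders):

// Default: MMOptions{} selects ParallelismType::kNested and cluster 0.
MatMul(A, B, add, env, C);

// Hypothetical override: restrict the matmul to a single cluster.
MMOptions options;
options.parallelism_type = ParallelismType::kWithinCluster;
options.cluster_idx = 1;  // illustrative value
MatMul(A, B, add, env, C, options);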
ops/matmul.h (11 changed lines)
@@ -27,6 +27,7 @@
 // IWYU pragma: begin_exports
 #include "util/basics.h"
 #include "util/mat.h"
+#include "util/threading.h"
 #include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
@@ -56,16 +57,6 @@ static constexpr size_t kMaxMR = 4;
 IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
                                  size_t N, size_t sizeof_TC, size_t nr);
 
-enum class ParallelismType : uint8_t {
-  kNone,
-  // No parallelism.
-  kSequential,
-  // Parallelism at cluster level.
-  kCluster,
-  // Parallelism at package level.
-  kNested,
-};
-
 struct MMOptions {
   ParallelismType parallelism_type = ParallelismType::kNested;
   uint8_t cluster_idx = 0;
@@ -494,14 +494,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(const XT* HWY_RESTRICT x,
 // Simple loops unless/until batch sizes are large enough to parallelize.
 template <typename XT, typename OT>
 void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
-                    MatPtrT<OT>& out, ThreadingContext& ctx) {
+                    MatPtrT<OT>& out, ThreadingContext& ctx,
+                    size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == activations.Cols());
   HWY_DASSERT(activations.SameShape(out));
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        activations.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, activations.Rows(), ctx.pools,
+        cluster_idx, [&](uint64_t token_idx, size_t worker) {
           RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(),
                   out.Row(token_idx), activations.Cols(), ctx.profiler, worker);
         });
@@ -510,13 +512,14 @@ void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
 
 template <typename XT>
 void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout,
-                           ThreadingContext& ctx) {
+                           ThreadingContext& ctx, size_t cluster_idx = 0) {
   HWY_DASSERT(weights.Rows() == 1);
   HWY_DASSERT(weights.Cols() == inout.Cols());
 
   CallUpcasted(&weights, [&](const auto* weights_t) {
-    SmallParallelFor(
-        inout.Rows(), ctx.pools, [&](uint64_t token_idx, size_t worker) {
+    ParallelFor(
+        ParallelismType::kAcrossClusters, inout.Rows(), ctx.pools, cluster_idx,
+        [&](uint64_t token_idx, size_t worker) {
           RMSNormInplace(weights_t->PackedScale1(), inout.Row(token_idx),
                          inout.Cols(), ctx.profiler, worker);
         });
@@ -542,13 +545,14 @@ void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
 
 template <typename XT>
 static HWY_INLINE void AddFromBatched(const MatPtrT<XT>& x, MatPtrT<float>& out,
-                                      ThreadingContext& ctx) {
+                                      ThreadingContext& ctx,
+                                      size_t cluster_idx = 0) {
   HWY_DASSERT(out.SameShape(x));
-  SmallParallelFor(out.Rows(), ctx.pools,
-                   [&](uint64_t token_idx, size_t worker) {
+  ParallelFor(ParallelismType::kAcrossClusters, out.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t token_idx, size_t worker) {
                AddFrom(x.Row(token_idx), out.Row(token_idx), x.Cols(),
                        ctx.profiler, worker);
              });
 }
 
 template <typename XT>
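The batched helpers in this file all follow the same pattern: a trailing `cluster_idx` defaulting to 0, with `kAcrossClusters` selected internally. A sketch of the call shapes, using the parameter names from the signatures above; since the defaults match the old behavior, these are exactly the pre-commit call sites:

// Existing call sites compile unchanged thanks to the default argument:
RMSNormBatched(activations, weights, out, ctx);
RMSNormInplaceBatched(weights, inout, ctx);
AddFromBatched(x, out, ctx);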
@@ -776,13 +780,15 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap(
 
 static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCapBatched(
     const float cap, MatPtrT<float>& x, const hwy::BitSet4096<>& non_eos,
-    ThreadingContext& ctx) {
+    ThreadingContext& ctx, size_t cluster_idx = 0) {
   if (cap == 0.0f) return;
-  SmallParallelFor(x.Rows(), ctx.pools, [&](uint64_t task, size_t worker) {
-    if (non_eos.Get(task)) {
-      LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler, worker);
-    }
-  });
+  ParallelFor(ParallelismType::kAcrossClusters, x.Rows(), ctx.pools,
+              cluster_idx, [&](uint64_t task, size_t worker) {
+                if (non_eos.Get(task)) {
+                  LogitsSoftCap(cap, x.Row(task), x.Cols(), ctx.profiler,
+                                worker);
+                }
+              });
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
@@ -326,7 +326,7 @@ void ParallelizeTwoRanges(const IndexRangePartition& get1,
 // Calls `func(task, worker)` for each task in `[0, num_tasks)`. Parallelizes
 // over clusters of ONE package, then within each cluster.
 template <class Func>
-void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
+void NestedParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   // Even if there are multiple packages, we only use the first.
   const size_t pkg_idx = 0;
 
@@ -356,14 +356,57 @@ void ParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
   });
 }
 
-// As above, but for lightweight tasks. Uses only one pool.
-template <class Func>
-void SmallParallelFor(size_t num_tasks, NestedPools& pools, const Func& func) {
-  // Even if there are multiple packages, we only use the first.
-  const size_t pkg_idx = 0;
-
-  pools.Pool(pkg_idx).Run(
-      0, num_tasks, [&](uint64_t task, size_t thread) { func(task, thread); });
+// Which pool(s) to use for parallelizing:
+enum class ParallelismType : uint8_t {
+  // None: single-threaded loop on the calling thread.
+  kSequential,
+  // One thread per cluster within the first package; or one per core if there
+  // is only one cluster. Use for few or lightweight tasks, or to maximize
+  // memory bandwidth availability.
+  kAcrossClusters,
+  // All cores within the cluster identified by `cluster_idx`. Use if already
+  // within a `kAcrossClusters` parallel-for, or if latency is more important
+  // than memory bandwidth.
+  kWithinCluster,
+  // First statically partitions `kAcrossClusters`, then `kWithinCluster`. This
+  // utilizes all cores, but has higher fork-join overhead (two barriers); use
+  // if there are many or heavy tasks.
+  kNested,
+};
+
+// Calls `func(task, worker)` for each `task` in `[0, num_tasks)`, with the
+// number/type of workers determined by `parallelism`. `cluster_idx` is only
+// used if `parallelism == kWithinCluster`.
+template <class Func>
+void ParallelFor(ParallelismType parallelism, size_t num_tasks,
+                 NestedPools& pools, size_t cluster_idx, const Func& func) {
+  if (cluster_idx != 0) {
+    // If already running across clusters, must not use across-cluster modes.
+    HWY_DASSERT(parallelism != ParallelismType::kAcrossClusters &&
+                parallelism != ParallelismType::kNested);
+  }
+
+  const size_t pkg_idx = 0;
+  switch (parallelism) {
+    case ParallelismType::kSequential:
+      for (size_t task = 0; task < num_tasks; ++task) {
+        func(task, /*worker=*/0);
+      }
+      return;
+
+    case ParallelismType::kAcrossClusters:
+      return pools.Pool(pkg_idx).Run(
+          0, num_tasks,
+          [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kWithinCluster:
+      return pools.Cluster(pkg_idx, cluster_idx)
+          .Run(0, num_tasks,
+               [&](uint64_t task, size_t worker) { func(task, worker); });
+
+    case ParallelismType::kNested:
+      return NestedParallelFor(num_tasks, pools, func);
+  }
 }
 
 }  // namespace gcpp
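For readers without the gemma.cpp threading stack, here is a self-contained toy model of the four dispatch modes using std::thread. Everything in it (ToyParallelFor, RunOn, the cluster/core counts) is a hypothetical illustration of the switch above, not a gemma.cpp API; the real NestedPools uses persistent worker pools, barriers, and static partitioning rather than spawning threads per call.

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

enum class ParallelismType : uint8_t {
  kSequential, kAcrossClusters, kWithinCluster, kNested
};

// Static round-robin partition of [0, num_tasks) over num_workers threads.
void RunOn(size_t num_workers, uint64_t num_tasks,
           const std::function<void(uint64_t, size_t)>& func) {
  std::vector<std::thread> threads;
  for (size_t w = 0; w < num_workers; ++w) {
    threads.emplace_back([w, num_workers, num_tasks, &func] {
      for (uint64_t t = w; t < num_tasks; t += num_workers) func(t, w);
    });
  }
  for (std::thread& t : threads) t.join();
}

void ToyParallelFor(ParallelismType parallelism, uint64_t num_tasks,
                    size_t num_clusters, size_t cores_per_cluster,
                    const std::function<void(uint64_t, size_t)>& func) {
  switch (parallelism) {
    case ParallelismType::kSequential:  // plain loop on the calling thread
      for (uint64_t t = 0; t < num_tasks; ++t) func(t, 0);
      return;
    case ParallelismType::kAcrossClusters:  // one worker per cluster
      return RunOn(num_clusters, num_tasks, func);
    case ParallelismType::kWithinCluster:  // all cores of one cluster
      return RunOn(cores_per_cluster, num_tasks, func);
    case ParallelismType::kNested: {  // split across clusters, then cores
      const uint64_t per = (num_tasks + num_clusters - 1) / num_clusters;
      std::vector<std::thread> clusters;
      for (size_t c = 0; c < num_clusters; ++c) {
        const uint64_t begin = std::min<uint64_t>(c * per, num_tasks);
        const uint64_t end = std::min<uint64_t>(begin + per, num_tasks);
        clusters.emplace_back([c, begin, end, cores_per_cluster, &func] {
          RunOn(cores_per_cluster, end - begin, [&](uint64_t t, size_t w) {
            func(begin + t, c * cores_per_cluster + w);
          });
        });
      }
      for (std::thread& t : clusters) t.join();
      return;
    }
  }
}

int main() {
  for (ParallelismType mode :
       {ParallelismType::kSequential, ParallelismType::kAcrossClusters,
        ParallelismType::kWithinCluster, ParallelismType::kNested}) {
    std::atomic<uint64_t> sum{0};
    ToyParallelFor(mode, /*num_tasks=*/1000, /*num_clusters=*/2,
                   /*cores_per_cluster=*/4,
                   [&](uint64_t task, size_t /*worker*/) { sum += task; });
    std::printf("mode %d -> sum %llu (expect 499500)\n",
                static_cast<int>(mode),
                static_cast<unsigned long long>(sum.load()));
  }
  return 0;
}

Each mode visits every task exactly once, so all four print the same sum; only the number and grouping of workers differs, which mirrors the trade-offs documented in the enum comments above.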