mirror of https://github.com/google/gemma.cpp.git
Add in-cluster parallel policy. Update policy to include cluster_idx.
PiperOrigin-RevId: 802016308
This commit is contained in:
parent
27cb8e12d9
commit
3737224132
|
|
@ -728,9 +728,10 @@ class MMPerPackage {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
|
MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
|
||||||
size_t pkg_idx, const IndexRange& range_np)
|
size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
|
||||||
: args_(args),
|
: args_(args),
|
||||||
pkg_idx_(pkg_idx),
|
pkg_idx_(pkg_idx),
|
||||||
|
cluster_idx_(cluster_idx),
|
||||||
range_np_(range_np),
|
range_np_(range_np),
|
||||||
mr_(config.MR()),
|
mr_(config.MR()),
|
||||||
ranges_mc_(config.RangesOfMC(A.rows)),
|
ranges_mc_(config.RangesOfMC(A.rows)),
|
||||||
|
|
@ -821,7 +822,8 @@ class MMPerPackage {
|
||||||
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
||||||
MMParallelPolicyT::ForNP(
|
MMParallelPolicyT::ForNP(
|
||||||
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
||||||
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
pkg_idx_, cluster_idx_,
|
||||||
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
|
|
@ -869,7 +871,8 @@ class MMPerPackage {
|
||||||
|
|
||||||
MMParallelPolicyT::ForNP(
|
MMParallelPolicyT::ForNP(
|
||||||
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
||||||
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
pkg_idx_, cluster_idx_,
|
||||||
|
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
mm_zone.MaybeEnter(worker, zone, args_);
|
mm_zone.MaybeEnter(worker, zone, args_);
|
||||||
|
|
||||||
|
|
@ -901,7 +904,7 @@ class MMPerPackage {
|
||||||
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
||||||
// except for the profiler strings and `out_tag`.
|
// except for the profiler strings and `out_tag`.
|
||||||
MMParallelPolicyT::ForRangesMC_NC(
|
MMParallelPolicyT::ForRangesMC_NC(
|
||||||
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
|
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
|
|
@ -951,7 +954,7 @@ class MMPerPackage {
|
||||||
}
|
}
|
||||||
}; // loop_nc
|
}; // loop_nc
|
||||||
MMParallelPolicyT::ForRangesMC_NC(
|
MMParallelPolicyT::ForRangesMC_NC(
|
||||||
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
|
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
size_t worker) HWY_ATTR {
|
size_t worker) HWY_ATTR {
|
||||||
MMZone mm_zone;
|
MMZone mm_zone;
|
||||||
|
|
@ -1024,7 +1027,7 @@ class MMPerPackage {
|
||||||
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
|
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
|
||||||
|
|
||||||
MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
|
MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
|
||||||
pkg_idx_,
|
pkg_idx_, cluster_idx_,
|
||||||
[&](const IndexRange& range_K, size_t worker) {
|
[&](const IndexRange& range_K, size_t worker) {
|
||||||
do_range(all_M, range_K, worker);
|
do_range(all_M, range_K, worker);
|
||||||
});
|
});
|
||||||
|
|
@ -1032,7 +1035,8 @@ class MMPerPackage {
|
||||||
}
|
}
|
||||||
case MMParA::kM:
|
case MMParA::kM:
|
||||||
MMParallelPolicyT::ForRangeMC(
|
MMParallelPolicyT::ForRangeMC(
|
||||||
args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
|
args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
|
||||||
|
[&](size_t row_a, size_t worker) {
|
||||||
do_range(IndexRange(row_a, row_a + 1), all_K, worker);
|
do_range(IndexRange(row_a, row_a + 1), all_K, worker);
|
||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
|
|
@ -1106,6 +1110,7 @@ class MMPerPackage {
|
||||||
|
|
||||||
const MMArgs args_; // copy for locality
|
const MMArgs args_; // copy for locality
|
||||||
const size_t pkg_idx_;
|
const size_t pkg_idx_;
|
||||||
|
const size_t cluster_idx_; // 0 for sequential and nested.
|
||||||
|
|
||||||
const IndexRange range_np_;
|
const IndexRange range_np_;
|
||||||
// From MMConfig:
|
// From MMConfig:
|
||||||
|
|
@ -1135,23 +1140,26 @@ struct MMImpl {
|
||||||
template <typename TA, typename TB, typename TC>
|
template <typename TA, typename TB, typename TC>
|
||||||
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
RowPtrs<TC> C_rows, const MMArgs& args,
|
RowPtrs<TC> C_rows, const MMArgs& args,
|
||||||
const MMConfig& config,
|
const MMConfig& config, MMOptions options) {
|
||||||
ParallelismType parallelism_type) {
|
|
||||||
PROFILER_ZONE("MM.DoMatMul");
|
PROFILER_ZONE("MM.DoMatMul");
|
||||||
const size_t pkg_idx = 0;
|
const size_t pkg_idx = 0;
|
||||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||||
|
|
||||||
switch (parallelism_type) {
|
switch (options.parallelism_type) {
|
||||||
case ParallelismType::kNested:
|
case ParallelismType::kNested:
|
||||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
HWY_DASSERT(options.cluster_idx == 0);
|
||||||
MMNestedParallelPolicy(), A, B, C_rows);
|
MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
|
||||||
|
range_np)(MMNestedParallelPolicy(), A, B, C_rows);
|
||||||
break;
|
break;
|
||||||
case ParallelismType::kSequential:
|
case ParallelismType::kSequential:
|
||||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
|
||||||
MMSequentialPolicy(), A, B, C_rows);
|
range_np)(MMSequentialPolicy(), A, B, C_rows);
|
||||||
case ParallelismType::kNone:
|
|
||||||
case ParallelismType::kCluster:
|
case ParallelismType::kCluster:
|
||||||
|
MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
|
||||||
|
range_np)(MMClusterParallelPolicy(), A, B, C_rows);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
HWY_ABORT("Parallelism type not implemented.");
|
HWY_ABORT("Parallelism type not implemented.");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -1210,8 +1218,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
||||||
add);
|
add);
|
||||||
if (HWY_LIKELY(tuner.Best())) {
|
if (HWY_LIKELY(tuner.Best())) {
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
|
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
|
||||||
options.parallelism_type);
|
|
||||||
return &per_key;
|
return &per_key;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1240,7 +1247,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
|
|
||||||
const MMConfig& cfg = tuner.NextConfig();
|
const MMConfig& cfg = tuner.NextConfig();
|
||||||
const uint64_t t0 = hwy::timer::Start();
|
const uint64_t t0 = hwy::timer::Start();
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
|
MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
|
||||||
const uint64_t t1 =
|
const uint64_t t1 =
|
||||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||||
|
|
|
||||||
80
ops/matmul.h
80
ops/matmul.h
|
|
@ -81,9 +81,10 @@ struct MMSequentialPolicy {
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
||||||
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
||||||
const Func& func) {
|
size_t cluster_idx, const Func& func) {
|
||||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||||
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
|
||||||
|
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||||
func(range_np, base_idx);
|
func(range_np, base_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -91,8 +92,10 @@ struct MMSequentialPolicy {
|
||||||
static void ForRangesMC_NC(ThreadingContext& ctx,
|
static void ForRangesMC_NC(ThreadingContext& ctx,
|
||||||
const IndexRangePartition& ranges_mc,
|
const IndexRangePartition& ranges_mc,
|
||||||
const IndexRangePartition& ranges_nc,
|
const IndexRangePartition& ranges_nc,
|
||||||
size_t pkg_idx, const Func& func) {
|
size_t pkg_idx, size_t cluster_idx,
|
||||||
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const Func& func) {
|
||||||
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
|
||||||
|
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||||
|
|
||||||
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
|
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
|
||||||
const IndexRange range_mc = ranges_mc.Range(i);
|
const IndexRange range_mc = ranges_mc.Range(i);
|
||||||
|
|
@ -105,14 +108,68 @@ struct MMSequentialPolicy {
|
||||||
|
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||||
size_t pkg_idx, const Func& func) {
|
size_t pkg_idx, size_t cluster_idx, const Func& func) {
|
||||||
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
|
||||||
|
cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||||
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
|
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
|
||||||
func(row_a, base_idx);
|
func(row_a, base_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MMClusterParallelPolicy {
|
||||||
|
template <class Func>
|
||||||
|
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
||||||
|
const Func& func) {
|
||||||
|
func(/*pkg_idx=*/0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
||||||
|
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
||||||
|
size_t cluster_idx, const Func& func) {
|
||||||
|
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||||
|
|
||||||
|
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||||
|
const IndexRangePartition worker_ranges = StaticPartition(
|
||||||
|
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
|
||||||
|
ParallelizeOneRange(worker_ranges, cluster,
|
||||||
|
[&](const IndexRange& worker_range, size_t worker) {
|
||||||
|
func(worker_range, worker);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForRangesMC_NC(ThreadingContext& ctx,
|
||||||
|
const IndexRangePartition& ranges_mc,
|
||||||
|
const IndexRangePartition& ranges_nc,
|
||||||
|
size_t pkg_idx, size_t cluster_idx,
|
||||||
|
const Func& func) {
|
||||||
|
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||||
|
|
||||||
|
// Low-batch: avoid Divide/Remainder.
|
||||||
|
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||||
|
ParallelizeOneRange(ranges_nc, cluster,
|
||||||
|
[&](const IndexRange& range_nc, size_t thread) {
|
||||||
|
func(ranges_mc.Range(0), range_nc, thread);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
ParallelizeTwoRanges(
|
||||||
|
ranges_mc, ranges_nc, cluster,
|
||||||
|
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||||
|
size_t thread) { func(range_mc, range_nc, thread); });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Func>
|
||||||
|
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||||
|
size_t pkg_idx, size_t cluster_idx, const Func& func) {
|
||||||
|
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||||
|
cluster.Run(range_mc.begin(), range_mc.end(),
|
||||||
|
[&](uint64_t row_a, size_t thread) { func(row_a, thread); });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct MMNestedParallelPolicy {
|
struct MMNestedParallelPolicy {
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
||||||
|
|
@ -132,10 +189,11 @@ struct MMNestedParallelPolicy {
|
||||||
|
|
||||||
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
|
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
|
||||||
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
|
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
|
||||||
|
// `cluster_idx` is not used here as all clusters within a package are used.
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
||||||
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
||||||
const Func& func) {
|
size_t /*cluster_idx*/, const Func& func) {
|
||||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
|
|
||||||
|
|
@ -175,11 +233,13 @@ struct MMNestedParallelPolicy {
|
||||||
|
|
||||||
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
|
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
|
||||||
// rows). Calls `func(range_mc, range_nc, worker)`.
|
// rows). Calls `func(range_mc, range_nc, worker)`.
|
||||||
|
// `cluster_idx` is not used here as all clusters within a package are used.
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForRangesMC_NC(ThreadingContext& ctx,
|
static void ForRangesMC_NC(ThreadingContext& ctx,
|
||||||
const IndexRangePartition& ranges_mc,
|
const IndexRangePartition& ranges_mc,
|
||||||
const IndexRangePartition& ranges_nc,
|
const IndexRangePartition& ranges_nc,
|
||||||
size_t pkg_idx, const Func& func) {
|
size_t pkg_idx, size_t /*cluster_idx*/,
|
||||||
|
const Func& func) {
|
||||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
|
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
|
||||||
// `all_clusters` is a pool with one worker per cluster in a package.
|
// `all_clusters` is a pool with one worker per cluster in a package.
|
||||||
|
|
@ -221,9 +281,11 @@ struct MMNestedParallelPolicy {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calls `func(row_a, worker)` in parallel.
|
// Calls `func(row_a, worker)` in parallel.
|
||||||
|
// `cluster_idx` is not used here as all clusters within a package are used.
|
||||||
template <class Func>
|
template <class Func>
|
||||||
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||||
size_t pkg_idx, const Func& func) {
|
size_t pkg_idx, size_t /*cluster_idx*/,
|
||||||
|
const Func& func) {
|
||||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||||
ctx.pools.Pool(pkg_idx).Run(
|
ctx.pools.Pool(pkg_idx).Run(
|
||||||
range_mc.begin(), range_mc.end(),
|
range_mc.begin(), range_mc.end(),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue