mirror of https://github.com/google/gemma.cpp.git
Add in-cluster parallel policy. Update policy to include cluster_idx.
PiperOrigin-RevId: 802016308
parent 27cb8e12d9
commit 3737224132
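This change threads a cluster_idx through every MatMul parallel policy so a
call can be confined to a single cluster (CCX) of a package rather than
fanning out across all of them. It adds the new MMClusterParallelPolicy,
renames ParallelismType::kSequential to kNone, and replaces DoMatMul's loose
ParallelismType parameter with an MMOptions bundle. A minimal sketch of what
MMOptions presumably carries, inferred from the call sites in this diff (the
defaults are assumptions, not from the source):

// Sketch only: the diff shows that MMOptions has parallelism_type and
// cluster_idx members; the defaults here are assumed.
struct MMOptions {
  ParallelismType parallelism_type = ParallelismType::kNested;
  size_t cluster_idx = 0;  // must stay 0 for the nested policy (see below)
};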
@@ -728,9 +728,10 @@ class MMPerPackage {
  public:
   MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
-               size_t pkg_idx, const IndexRange& range_np)
+               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
       : args_(args),
         pkg_idx_(pkg_idx),
+        cluster_idx_(cluster_idx),
         range_np_(range_np),
         mr_(config.MR()),
         ranges_mc_(config.RangesOfMC(A.rows)),

@@ -821,7 +822,8 @@ class MMPerPackage {
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     MMParallelPolicyT::ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+        pkg_idx_, cluster_idx_,
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);

@@ -869,7 +871,8 @@ class MMPerPackage {
     MMParallelPolicyT::ForNP(
         args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
-        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+        pkg_idx_, cluster_idx_,
+        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);

@@ -901,7 +904,7 @@ class MMPerPackage {
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
     MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;

@@ -951,7 +954,7 @@ class MMPerPackage {
       }
     };  // loop_nc
     MMParallelPolicyT::ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_, cluster_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;

@@ -1024,7 +1027,7 @@ class MMPerPackage {
         const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));

         MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
-                                 pkg_idx_,
+                                 pkg_idx_, cluster_idx_,
                                  [&](const IndexRange& range_K, size_t worker) {
                                    do_range(all_M, range_K, worker);
                                  });

@@ -1032,7 +1035,8 @@ class MMPerPackage {
       }
       case MMParA::kM:
         MMParallelPolicyT::ForRangeMC(
-            args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+            args_.env->ctx, all_M, pkg_idx_, cluster_idx_,
+            [&](size_t row_a, size_t worker) {
               do_range(IndexRange(row_a, row_a + 1), all_K, worker);
             });
         break;

@@ -1106,6 +1110,7 @@ class MMPerPackage {
   const MMArgs args_;  // copy for locality
   const size_t pkg_idx_;
+  const size_t cluster_idx_;  // 0 for sequential and nested.

   const IndexRange range_np_;
   // From MMConfig:
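The hunks above make every policy invocation in MMPerPackage pass cluster_idx_
immediately after pkg_idx_. Restated as a declaration-only sketch (signatures
taken from the ops/matmul.h hunks below, callback shapes inferred from the
lambdas above; the struct name is hypothetical):

// Interface every MM*Policy now implements.
struct MMPolicyInterfaceSketch {
  template <class Func>  // func(const IndexRange& range_nc, size_t worker)
  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
                    size_t cluster_idx, const Func& func);

  template <class Func>  // func(range_mc, range_nc, worker)
  static void ForRangesMC_NC(ThreadingContext& ctx,
                             const IndexRangePartition& ranges_mc,
                             const IndexRangePartition& ranges_nc,
                             size_t pkg_idx, size_t cluster_idx,
                             const Func& func);

  template <class Func>  // func(size_t row_a, size_t worker)
  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                         size_t pkg_idx, size_t cluster_idx, const Func& func);
};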
@@ -1135,23 +1140,26 @@ struct MMImpl {
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                     RowPtrs<TC> C_rows, const MMArgs& args,
-                                    const MMConfig& config,
-                                    ParallelismType parallelism_type) {
+                                    const MMConfig& config, MMOptions options) {
     PROFILER_ZONE("MM.DoMatMul");
     const size_t pkg_idx = 0;
     HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
     const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);

-    switch (parallelism_type) {
+    switch (options.parallelism_type) {
       case ParallelismType::kNested:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMNestedParallelPolicy(), A, B, C_rows);
+        HWY_DASSERT(options.cluster_idx == 0);
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMNestedParallelPolicy(), A, B, C_rows);
         break;
-      case ParallelismType::kSequential:
-        MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
-            MMSequentialPolicy(), A, B, C_rows);
+      case ParallelismType::kNone:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMSequentialPolicy(), A, B, C_rows);
+        break;
+      case ParallelismType::kCluster:
+        MMPerPackage(A.Extents(), args, config, pkg_idx, options.cluster_idx,
+                     range_np)(MMClusterParallelPolicy(), A, B, C_rows);
         break;
       default:
         HWY_ABORT("Parallelism type not implemented.");
         break;
     }
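A hypothetical call site for the new kCluster path (only the MMOptions fields
and the DoMatMul signature are grounded in this diff; A, B, C_rows, args and
config come from the surrounding MatMul machinery):

// Run the entire MatMul on cluster 2 of package 0. Note the HWY_DASSERT
// above: kNested requires cluster_idx == 0.
MMOptions options;
options.parallelism_type = ParallelismType::kCluster;
options.cluster_idx = 2;
MMImpl::DoMatMul(A, B, C_rows, args, config, options);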
@@ -1210,8 +1218,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
                     add);
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(),
-                     options.parallelism_type);
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), options);
     return &per_key;
   }

@@ -1240,7 +1247,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMConfig& cfg = tuner.NextConfig();
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, C_rows, args, cfg, options.parallelism_type);
+  MMImpl::DoMatMul(A, B, C_rows, args, cfg, options);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
ops/matmul.h (80 changed lines)
@@ -81,9 +81,10 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t cluster_idx, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     func(range_np, base_idx);
   }
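The widened base_idx presumably keeps per-worker state (e.g. the MMZone
profiler entries indexed by `worker`) disjoint across clusters when several
sequential MatMuls run on different clusters at once. A worked example with
illustrative pool sizes (both constants are assumptions, not values queried
from ThreadingContext):

// Suppose MaxWorkersPerPackage() == 32 and MaxWorkersPerCluster() == 8.
const size_t max_workers_per_pkg = 32;     // assumed
const size_t max_workers_per_cluster = 8;  // assumed
// For pkg_idx = 1, cluster_idx = 2:
const size_t base_idx = 1 * max_workers_per_pkg + 2 * max_workers_per_cluster;
// base_idx == 48, disjoint from cluster 0 (base 32) and cluster 1 (base 40)
// of the same package.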
@@ -91,8 +92,10 @@ struct MMSequentialPolicy {
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();

     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
@@ -105,14 +108,68 @@ struct MMSequentialPolicy {
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
-    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage() +
+                            cluster_idx * ctx.pools.MaxWorkersPerCluster();
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
       func(row_a, base_idx);
     }
   }
 };

+struct MMClusterParallelPolicy {
+  template <class Func>
+  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
+                     const Func& func) {
+    func(/*pkg_idx=*/0);
+  }
+
+  template <class Func>
+  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
+                    size_t cluster_idx, const Func& func) {
+    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
+
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    const IndexRangePartition worker_ranges = StaticPartition(
+        range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
+    ParallelizeOneRange(worker_ranges, cluster,
+                        [&](const IndexRange& worker_range, size_t worker) {
+                          func(worker_range, worker);
+                        });
+  }
+
+  template <class Func>
+  static void ForRangesMC_NC(ThreadingContext& ctx,
+                             const IndexRangePartition& ranges_mc,
+                             const IndexRangePartition& ranges_nc,
+                             size_t pkg_idx, size_t cluster_idx,
+                             const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+
+    // Low-batch: avoid Divide/Remainder.
+    if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
+      ParallelizeOneRange(ranges_nc, cluster,
+                          [&](const IndexRange& range_nc, size_t thread) {
+                            func(ranges_mc.Range(0), range_nc, thread);
+                          });
+    } else {
+      ParallelizeTwoRanges(
+          ranges_mc, ranges_nc, cluster,
+          [&](const IndexRange& range_mc, const IndexRange& range_nc,
+              size_t thread) { func(range_mc, range_nc, thread); });
+    }
+  }
+
+  template <class Func>
+  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                         size_t pkg_idx, size_t cluster_idx, const Func& func) {
+    hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
+    cluster.Run(range_mc.begin(), range_mc.end(),
+                [&](uint64_t row_a, size_t thread) { func(row_a, thread); });
+  }
+};
+
 struct MMNestedParallelPolicy {
   template <class Func>
   static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
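ForNP above statically splits range_np into at most NumWorkers() * inner_tasks
pieces aligned to nx_multiple, then lets the cluster pool execute one piece
per task. A self-contained sketch of that partitioning arithmetic (an
illustration of the assumed StaticPartition semantics, not the library's
implementation):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Split [begin, end) into at most max_tasks contiguous ranges whose sizes
// are rounded up to `multiple`; the final range absorbs the remainder.
std::vector<std::pair<size_t, size_t>> PartitionSketch(size_t begin, size_t end,
                                                       size_t max_tasks,
                                                       size_t multiple) {
  std::vector<std::pair<size_t, size_t>> ranges;
  size_t per_task = (end - begin + max_tasks - 1) / max_tasks;  // ceil
  per_task = (per_task + multiple - 1) / multiple * multiple;   // round up
  for (size_t start = begin; start < end; start += per_task) {
    ranges.emplace_back(start, std::min(end, start + per_task));
  }
  return ranges;
}
// PartitionSketch(0, 100, 8, 16) -> {0,16} {16,32} ... {80,96} {96,100}.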
@@ -132,10 +189,11 @@ struct MMNestedParallelPolicy {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
                     size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
-                    const Func& func) {
+                    size_t /*cluster_idx*/, const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();

@@ -175,11 +233,13 @@ struct MMNestedParallelPolicy {
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangesMC_NC(ThreadingContext& ctx,
                              const IndexRangePartition& ranges_mc,
                              const IndexRangePartition& ranges_nc,
-                             size_t pkg_idx, const Func& func) {
+                             size_t pkg_idx, size_t /*cluster_idx*/,
+                             const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.

@@ -221,9 +281,11 @@ struct MMNestedParallelPolicy {
   }

   // Calls `func(row_a, worker)` in parallel.
+  // `cluster_idx` is not used here as all clusters within a package are used.
   template <class Func>
   static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
-                         size_t pkg_idx, const Func& func) {
+                         size_t pkg_idx, size_t /*cluster_idx*/,
+                         const Func& func) {
     const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
     ctx.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),
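What the new policy enables, sketched as a hypothetical caller: one
independent MatMul per cluster, each confined to its own cluster's threads.
Only ctx.pools.AllClusters, the MMOptions fields, and ParallelismType::kCluster
are grounded in this diff; `ctx` is the ThreadingContext and the per-cluster
work is left as a placeholder.

const size_t pkg_idx = 0;
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
// One worker per cluster; each selects its own cluster via options.
all_clusters.Run(0, all_clusters.NumWorkers(),
                 [&](uint64_t cluster_idx, size_t /*thread*/) {
                   MMOptions options;
                   options.parallelism_type = ParallelismType::kCluster;
                   options.cluster_idx = cluster_idx;
                   // Hypothetical per-cluster problem:
                   // MMImpl::DoMatMul(A[cluster_idx], B[cluster_idx],
                   //                  C_rows[cluster_idx], args, config,
                   //                  options);
                 });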