mirror of https://github.com/google/gemma.cpp.git
Include parallelism type in DoMatMul. Also remove package handling.
PiperOrigin-RevId: 800902568
This commit is contained in:
parent
0ae8646731
commit
00b70f69c5
|
|
@ -1135,28 +1135,23 @@ struct MMImpl {
|
||||||
template <typename TA, typename TB, typename TC>
|
template <typename TA, typename TB, typename TC>
|
||||||
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
RowPtrs<TC> C_rows, const MMArgs& args,
|
RowPtrs<TC> C_rows, const MMArgs& args,
|
||||||
const MMConfig& config) {
|
const MMConfig& config,
|
||||||
|
ParallelismType parallelism_type) {
|
||||||
PROFILER_ZONE("MM.DoMatMul");
|
PROFILER_ZONE("MM.DoMatMul");
|
||||||
static const auto zone =
|
|
||||||
args.env->ctx.profiler.AddZone("MM.DoMatMul.PerPkg");
|
|
||||||
|
|
||||||
if constexpr (kMaxPackages > 1) {
|
|
||||||
// Outermost loop: static NUMA-aware partition of B rows across packages.
|
|
||||||
MMNestedParallelPolicy::ForPkg(
|
|
||||||
args.env->ctx, args.per_key->ranges_np.NumTasks(),
|
|
||||||
[&](size_t pkg_idx) {
|
|
||||||
MMZone mm_zone;
|
|
||||||
mm_zone.MaybeEnter(pkg_idx, zone, args);
|
|
||||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
|
||||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
|
||||||
MMNestedParallelPolicy(), A, B, C_rows);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
const size_t pkg_idx = 0;
|
const size_t pkg_idx = 0;
|
||||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||||
|
|
||||||
|
switch (parallelism_type) {
|
||||||
|
case ParallelismType::kNested:
|
||||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||||
MMNestedParallelPolicy(), A, B, C_rows);
|
MMNestedParallelPolicy(), A, B, C_rows);
|
||||||
|
break;
|
||||||
|
case ParallelismType::kNone:
|
||||||
|
case ParallelismType::kSequential:
|
||||||
|
case ParallelismType::kCluster:
|
||||||
|
HWY_ABORT("Parallelism type not implemented.");
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -1210,10 +1205,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
MMPerKey& per_key = env.per_key[index];
|
MMPerKey& per_key = env.per_key[index];
|
||||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||||
|
|
||||||
|
// Default to nested parallelism.
|
||||||
|
const ParallelismType parallelism_type = ParallelismType::kNested;
|
||||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
||||||
add);
|
add);
|
||||||
if (HWY_LIKELY(tuner.Best())) {
|
if (HWY_LIKELY(tuner.Best())) {
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
|
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
|
||||||
return &per_key;
|
return &per_key;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1242,7 +1239,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||||
|
|
||||||
const MMConfig& cfg = tuner.NextConfig();
|
const MMConfig& cfg = tuner.NextConfig();
|
||||||
const uint64_t t0 = hwy::timer::Start();
|
const uint64_t t0 = hwy::timer::Start();
|
||||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg);
|
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
|
||||||
const uint64_t t1 =
|
const uint64_t t1 =
|
||||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue