Include parallelism type in DoMatMul. Also remove package handling.

PiperOrigin-RevId: 800902568
This commit is contained in:
Marie White 2025-08-29 08:04:05 -07:00 committed by Copybara-Service
parent 0ae8646731
commit 00b70f69c5
1 changed file with 19 additions and 22 deletions

View File

@ -1135,28 +1135,23 @@ struct MMImpl {
template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows, const MMArgs& args,
const MMConfig& config) {
const MMConfig& config,
ParallelismType parallelism_type) {
PROFILER_ZONE("MM.DoMatMul");
static const auto zone =
args.env->ctx.profiler.AddZone("MM.DoMatMul.PerPkg");
const size_t pkg_idx = 0;
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
if constexpr (kMaxPackages > 1) {
// Outermost loop: static NUMA-aware partition of B rows across packages.
MMNestedParallelPolicy::ForPkg(
args.env->ctx, args.per_key->ranges_np.NumTasks(),
[&](size_t pkg_idx) {
MMZone mm_zone;
mm_zone.MaybeEnter(pkg_idx, zone, args);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
});
} else {
const size_t pkg_idx = 0;
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
switch (parallelism_type) {
case ParallelismType::kNested:
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
MMNestedParallelPolicy(), A, B, C_rows);
break;
case ParallelismType::kNone:
case ParallelismType::kSequential:
case ParallelismType::kCluster:
HWY_ABORT("Parallelism type not implemented.");
break;
}
}
};
@ -1210,10 +1205,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;
// Default to nested parallelism.
const ParallelismType parallelism_type = ParallelismType::kNested;
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
add);
if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
return &per_key;
}
@ -1242,7 +1239,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const MMConfig& cfg = tuner.NextConfig();
const uint64_t t0 = hwy::timer::Start();
MMImpl::DoMatMul(A, B, C_rows, args, cfg);
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /