mirror of https://github.com/google/gemma.cpp.git
Include parallelism type in DoMatMul. Also remove package handling.
PiperOrigin-RevId: 800902568
This commit is contained in:
parent
0ae8646731
commit
00b70f69c5
|
|
@ -1135,28 +1135,23 @@ struct MMImpl {
|
|||
template <typename TA, typename TB, typename TC>
|
||||
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows, const MMArgs& args,
|
||||
const MMConfig& config) {
|
||||
const MMConfig& config,
|
||||
ParallelismType parallelism_type) {
|
||||
PROFILER_ZONE("MM.DoMatMul");
|
||||
static const auto zone =
|
||||
args.env->ctx.profiler.AddZone("MM.DoMatMul.PerPkg");
|
||||
const size_t pkg_idx = 0;
|
||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
|
||||
if constexpr (kMaxPackages > 1) {
|
||||
// Outermost loop: static NUMA-aware partition of B rows across packages.
|
||||
MMNestedParallelPolicy::ForPkg(
|
||||
args.env->ctx, args.per_key->ranges_np.NumTasks(),
|
||||
[&](size_t pkg_idx) {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(pkg_idx, zone, args);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||
MMNestedParallelPolicy(), A, B, C_rows);
|
||||
});
|
||||
} else {
|
||||
const size_t pkg_idx = 0;
|
||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||
MMNestedParallelPolicy(), A, B, C_rows);
|
||||
switch (parallelism_type) {
|
||||
case ParallelismType::kNested:
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||
MMNestedParallelPolicy(), A, B, C_rows);
|
||||
break;
|
||||
case ParallelismType::kNone:
|
||||
case ParallelismType::kSequential:
|
||||
case ParallelismType::kCluster:
|
||||
HWY_ABORT("Parallelism type not implemented.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -1210,10 +1205,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
MMPerKey& per_key = env.per_key[index];
|
||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||
|
||||
// Default to nested parallelism.
|
||||
const ParallelismType parallelism_type = ParallelismType::kNested;
|
||||
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
|
||||
add);
|
||||
if (HWY_LIKELY(tuner.Best())) {
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best(), parallelism_type);
|
||||
return &per_key;
|
||||
}
|
||||
|
||||
|
|
@ -1242,7 +1239,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg);
|
||||
MMImpl::DoMatMul(A, B, C_rows, args, cfg, parallelism_type);
|
||||
const uint64_t t1 =
|
||||
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
|
||||
|
|
|
|||
Loading…
Reference in New Issue