mirror of https://github.com/google/gemma.cpp.git
Refactor Matmul to use a policy class for parallelization.
PiperOrigin-RevId: 800864489
This commit is contained in:
parent
6c39a2dea4
commit
973e284ed6
|
|
@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
|
|||
const size_t start = owners.size();
|
||||
owners.resize(start + tensors.size());
|
||||
|
||||
MMParallel parallel(ctx);
|
||||
|
||||
// Allocate in parallel because faulting in large tensors is slow.
|
||||
ctx.pools.Pool().Run(
|
||||
0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
|
||||
|
|
@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
|
|||
|
||||
owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
|
||||
tensor.padding);
|
||||
BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
|
||||
BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
102
ops/matmul-inl.h
102
ops/matmul-inl.h
|
|
@ -779,17 +779,19 @@ class MMPerPackage {
|
|||
// B and maybe A are decompressed several call layers lower, but not all
|
||||
// member functions depend on TA/TB, so pass them as an argument instead of
|
||||
// templating the class.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
|
||||
const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
if constexpr (WantDecompressA<TA>()) {
|
||||
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
|
||||
DecompressA(A, A_view);
|
||||
DecompressA<MMParallelPolicyT>(A, A_view);
|
||||
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
|
||||
DispatchOrder(A_view, A_padded, B, C_rows);
|
||||
DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
|
||||
} else {
|
||||
const bool A_padded = HasPadding(A);
|
||||
DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
|
||||
DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
|
||||
C_rows);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -828,28 +830,30 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded,
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
|
||||
const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
switch (order_) {
|
||||
case MMOrder::kNT:
|
||||
return DoNT(A, A_padded, B, C_rows);
|
||||
return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_K:
|
||||
return DoNT_K(A, A_padded, B, C_rows);
|
||||
return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_MT:
|
||||
return DoNT_MT(A, A_padded, B, C_rows);
|
||||
return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_MT_K:
|
||||
return DoNT_MT_K(A, A_padded, B, C_rows);
|
||||
return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
|
||||
default:
|
||||
HWY_UNREACHABLE;
|
||||
}
|
||||
}
|
||||
|
||||
// Single M and K ranges, parallel N. Fills all of C directly.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
|
||||
const bool A_padded, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
|
||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||
|
|
@ -861,9 +865,9 @@ class MMPerPackage {
|
|||
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
|
||||
|
||||
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
||||
args_.env->parallel.ForNP(
|
||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||
MMParallelPolicyT::ForNP(
|
||||
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
||||
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(worker, zone, args_);
|
||||
|
||||
|
|
@ -881,9 +885,10 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// Single M range, parallel N, sequential K. Sets C, then accumulates.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
|
||||
const bool A_padded, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
|
||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||
const IndexRange& range_mc = ranges_mc_.Range(0);
|
||||
|
|
@ -909,9 +914,9 @@ class MMPerPackage {
|
|||
}
|
||||
};
|
||||
|
||||
args_.env->parallel.ForNP(
|
||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||
MMParallelPolicyT::ForNP(
|
||||
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
|
||||
pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(worker, zone, args_);
|
||||
|
||||
|
|
@ -930,9 +935,10 @@ class MMPerPackage {
|
|||
|
||||
// Parallel loops over mc/nc blocks of M/range_np, single K.
|
||||
// Fills `mc x nc` sections of C directly, in parallel.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
|
||||
const bool A_padded, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
|
||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||
|
|
@ -942,8 +948,8 @@ class MMPerPackage {
|
|||
|
||||
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
||||
// except for the profiler strings and `out_tag`.
|
||||
args_.env->parallel.ForRangesMC_NC(
|
||||
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
MMParallelPolicyT::ForRangesMC_NC(
|
||||
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||
size_t worker) HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
|
|
@ -965,9 +971,10 @@ class MMPerPackage {
|
|||
|
||||
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
|
||||
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
|
||||
HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
|
||||
const bool A_padded, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
|
||||
const size_t kc_max = ranges_kc_.TaskSize();
|
||||
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
|
||||
|
|
@ -992,8 +999,8 @@ class MMPerPackage {
|
|||
out_tag, args_, C_rows);
|
||||
}
|
||||
}; // loop_nc
|
||||
args_.env->parallel.ForRangesMC_NC(
|
||||
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
MMParallelPolicyT::ForRangesMC_NC(
|
||||
args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc,
|
||||
size_t worker) HWY_ATTR {
|
||||
MMZone mm_zone;
|
||||
|
|
@ -1014,6 +1021,7 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
||||
template <typename MMParallelPolicyT>
|
||||
HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMParA par_a) const {
|
||||
|
|
@ -1064,16 +1072,16 @@ class MMPerPackage {
|
|||
// line to avoid false sharing.
|
||||
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
|
||||
|
||||
args_.env->parallel.ForNP(
|
||||
all_K, multiple_K, inner_tasks, pkg_idx_,
|
||||
MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
|
||||
pkg_idx_,
|
||||
[&](const IndexRange& range_K, size_t worker) {
|
||||
do_range(all_M, range_K, worker);
|
||||
});
|
||||
break;
|
||||
}
|
||||
case MMParA::kM:
|
||||
args_.env->parallel.ForRangeMC(
|
||||
all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
|
||||
MMParallelPolicyT::ForRangeMC(
|
||||
args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
|
||||
do_range(IndexRange(row_a, row_a + 1), all_K, worker);
|
||||
});
|
||||
break;
|
||||
|
|
@ -1081,12 +1089,13 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// Autotuning wrapper for `DoDecompressA`.
|
||||
template <typename MMParallelPolicyT>
|
||||
HWY_INLINE void DecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view) const {
|
||||
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
|
||||
|
||||
if (HWY_LIKELY(autotune.Best())) {
|
||||
return DoDecompressA(A, A_view, *autotune.Best());
|
||||
return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
|
||||
}
|
||||
|
||||
// First call: generate candidates.
|
||||
|
|
@ -1099,7 +1108,7 @@ class MMPerPackage {
|
|||
|
||||
const MMParA& par_a = autotune.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
DoDecompressA(A, A_view, par_a);
|
||||
DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
|
||||
const uint64_t t1 =
|
||||
args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
||||
|
|
@ -1185,19 +1194,21 @@ struct MMImpl {
|
|||
|
||||
if constexpr (kMaxPackages > 1) {
|
||||
// Outermost loop: static NUMA-aware partition of B rows across packages.
|
||||
args.env->parallel.ForPkg(
|
||||
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
|
||||
MMNestedParallelPolicy::ForPkg(
|
||||
args.env->ctx, args.per_key->ranges_np.NumTasks(),
|
||||
[&](size_t pkg_idx) {
|
||||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(pkg_idx, zone, args);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
|
||||
C_rows);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||
MMNestedParallelPolicy(), A, B, C_rows);
|
||||
});
|
||||
} else {
|
||||
const size_t pkg_idx = 0;
|
||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
|
||||
MMNestedParallelPolicy(), A, B, C_rows);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
|
||||
// invalidates `MMAutoTune::Best()`
|
||||
index = env.per_key.size();
|
||||
env.per_key.push_back(
|
||||
MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
|
||||
env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
|
||||
}
|
||||
MMPerKey& per_key = env.per_key[index];
|
||||
MMAutoTune<MMConfig>& tuner = per_key.autotune;
|
||||
|
|
|
|||
|
|
@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
|
|||
return np_multiple;
|
||||
}
|
||||
|
||||
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
|
||||
size_t sizeof_TC, size_t nr) const {
|
||||
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
|
||||
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
|
||||
size_t N, size_t sizeof_TC, size_t nr) {
|
||||
const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
|
||||
return StaticPartition(
|
||||
IndexRange(0, N), num_packages,
|
||||
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
|
||||
NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
|
||||
}
|
||||
|
||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
|
||||
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
|
||||
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
|
||||
char cpu100[100];
|
||||
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
|
||||
|
||||
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM)); // C
|
||||
}
|
||||
|
||||
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
||||
Allocator& allocator = parallel.allocator();
|
||||
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
|
||||
Allocator& allocator = ctx.allocator;
|
||||
if (!allocator.ShouldBind()) return;
|
||||
if (B.Rows() == 1) return;
|
||||
|
||||
PROFILER_ZONE("Startup.BindB");
|
||||
|
||||
const IndexRangePartition ranges_np =
|
||||
parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
|
||||
MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
|
||||
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
||||
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
|
||||
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
|
||||
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
|
||||
// B row padding is less than the page size, so only bind the subset that
|
||||
|
|
@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
|
|||
}
|
||||
|
||||
// C is BF16/float
|
||||
void BindC(MatPtr& C, MMParallel& parallel) {
|
||||
Allocator& allocator = parallel.allocator();
|
||||
void BindC(ThreadingContext& ctx, MatPtr& C) {
|
||||
Allocator& allocator = ctx.allocator;
|
||||
if (!allocator.ShouldBind()) return;
|
||||
|
||||
PROFILER_ZONE("Startup.BindC");
|
||||
|
||||
const IndexRangePartition ranges_np =
|
||||
parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
|
||||
MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
|
||||
bool ok = true;
|
||||
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
||||
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
|
||||
|
|
@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) {
|
|||
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
|
||||
allocator.BasePageBytes());
|
||||
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
|
||||
for (size_t im = 0; im < C.Rows(); ++im) {
|
||||
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
|
||||
}
|
||||
|
|
|
|||
107
ops/matmul.h
107
ops/matmul.h
|
|
@ -53,35 +53,31 @@ constexpr size_t kNR = 4;
|
|||
// or less on ISAs with fewer registers, or for the last few rows of A.
|
||||
static constexpr size_t kMaxMR = 4;
|
||||
|
||||
// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
|
||||
// the ThreadingContext to shorten call sites.
|
||||
class MMParallel {
|
||||
public:
|
||||
// `ctx` must outlive this object.
|
||||
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
|
||||
if (ctx_.pools.NumPackages() > kMaxPackages) {
|
||||
HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
|
||||
ctx_.pools.NumPackages(), kMaxPackages);
|
||||
}
|
||||
}
|
||||
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
|
||||
size_t N, size_t sizeof_TC, size_t nr);
|
||||
|
||||
Allocator& allocator() const { return ctx_.allocator; }
|
||||
enum class ParallelismType : uint8_t {
|
||||
kNone,
|
||||
// No parallelism.
|
||||
kSequential,
|
||||
// Parallelism at cluster level.
|
||||
kCluster,
|
||||
// Parallelism at package level.
|
||||
kNested,
|
||||
};
|
||||
|
||||
// Initial static partitioning of B rows across packages.
|
||||
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
|
||||
size_t sizeof_TC, size_t nr) const;
|
||||
struct MMOptions {
|
||||
ParallelismType parallelism_type_ = ParallelismType::kNested;
|
||||
uint8_t cluster_idx_ = 0;
|
||||
};
|
||||
|
||||
// For `BindB` and `BindC`.
|
||||
size_t Node(size_t pkg_idx) const {
|
||||
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
|
||||
}
|
||||
|
||||
// Calls `func(pkg_idx)` for each package in parallel.
|
||||
struct MMNestedParallelPolicy {
|
||||
template <class Func>
|
||||
void ForPkg(const size_t max_packages, const Func& func) {
|
||||
static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
|
||||
const Func& func) {
|
||||
if constexpr (kMaxPackages > 1) {
|
||||
ctx_.pools.AllPackages().Run(
|
||||
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
|
||||
ctx.pools.AllPackages().Run(
|
||||
0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
|
||||
[&](uint64_t task, size_t pkg_idx) {
|
||||
HWY_DASSERT(task == pkg_idx);
|
||||
(void)task;
|
||||
|
|
@ -95,16 +91,17 @@ class MMParallel {
|
|||
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
|
||||
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
|
||||
template <class Func>
|
||||
void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
|
||||
size_t pkg_idx, const Func& func) {
|
||||
static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
|
||||
size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
|
||||
const Func& func) {
|
||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
|
||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||
|
||||
// Single cluster: parallel-for over static partition of `range_np`.
|
||||
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
|
||||
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
if (num_clusters == 1) {
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
|
||||
return ParallelizeOneRange(
|
||||
|
|
@ -120,9 +117,9 @@ class MMParallel {
|
|||
ParallelizeOneRange(
|
||||
nx_ranges, all_clusters,
|
||||
[&](const IndexRange& nx_range, const size_t cluster_idx) {
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
const size_t cluster_base =
|
||||
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
|
||||
pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
|
||||
|
|
@ -137,18 +134,19 @@ class MMParallel {
|
|||
// Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
|
||||
// rows). Calls `func(range_mc, range_nc, worker)`.
|
||||
template <class Func>
|
||||
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc, size_t pkg_idx,
|
||||
const Func& func) {
|
||||
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
|
||||
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
|
||||
static void ForRangesMC_NC(ThreadingContext& ctx,
|
||||
const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc,
|
||||
size_t pkg_idx, const Func& func) {
|
||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
|
||||
// `all_clusters` is a pool with one worker per cluster in a package.
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
// Single (big) cluster: collapse two range indices into one parallel-for
|
||||
// to reduce the number of fork-joins.
|
||||
if (num_clusters == 1) {
|
||||
const size_t cluster_idx = 0;
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
// Low-batch: avoid Divide/Remainder.
|
||||
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||
return ParallelizeOneRange(
|
||||
|
|
@ -171,8 +169,8 @@ class MMParallel {
|
|||
ranges_nc, all_clusters,
|
||||
[&](const IndexRange range_nc, size_t cluster_idx) {
|
||||
const size_t cluster_base =
|
||||
pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
|
||||
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
|
||||
ParallelizeOneRange(ranges_mc, cluster,
|
||||
[&](const IndexRange& range_mc, size_t thread) {
|
||||
func(range_mc, range_nc, cluster_base + thread);
|
||||
|
|
@ -182,21 +180,18 @@ class MMParallel {
|
|||
|
||||
// Calls `func(row_a, worker)` in parallel.
|
||||
template <class Func>
|
||||
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
|
||||
const Func& func) {
|
||||
const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
|
||||
ctx_.pools.Pool(pkg_idx).Run(
|
||||
static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
|
||||
size_t pkg_idx, const Func& func) {
|
||||
const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
|
||||
ctx.pools.Pool(pkg_idx).Run(
|
||||
range_mc.begin(), range_mc.end(),
|
||||
[&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
|
||||
}
|
||||
|
||||
private:
|
||||
ThreadingContext& ctx_;
|
||||
};
|
||||
|
||||
void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
|
||||
void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
|
||||
// C is BF16/float.
|
||||
void BindC(MatPtr& C, MMParallel& parallel);
|
||||
void BindC(ThreadingContext& ctx, MatPtr& C);
|
||||
|
||||
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
|
||||
#pragma pack(push, 1) // power of two size
|
||||
|
|
@ -250,15 +245,18 @@ class MMStorage {
|
|||
|
||||
// Internally threaded; must not be called concurrently with the same
|
||||
// `ThreadingContext` (used via `parallel`).
|
||||
MMStorage(const Allocator& allocator, MMParallel& parallel) {
|
||||
MMStorage(ThreadingContext& ctx) {
|
||||
// Per-package allocation so each can decompress A into its own copy.
|
||||
// Must be padded, see `DoDecompressA`.
|
||||
parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
|
||||
// Default to nested parallel policy.
|
||||
MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
|
||||
Allocator& allocator = ctx.allocator;
|
||||
|
||||
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
|
||||
"pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
|
||||
|
||||
if (allocator.ShouldBind()) {
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
|
||||
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
|
||||
pkg_A_[pkg_idx]->ElementBytes();
|
||||
bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
|
||||
|
|
@ -607,9 +605,9 @@ class MMKeys {
|
|||
|
||||
// Per-MatMul-shape state.
|
||||
struct MMPerKey {
|
||||
MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
|
||||
MMParallel& parallel)
|
||||
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
|
||||
MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
|
||||
size_t sizeof_TC, size_t nr)
|
||||
: ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
|
||||
HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
|
||||
}
|
||||
|
||||
|
|
@ -639,7 +637,6 @@ struct MatMulEnv {
|
|||
// Whether to print the best config immediately after autotuning finished.
|
||||
bool print_best = false;
|
||||
|
||||
MMParallel parallel;
|
||||
MMStorage storage;
|
||||
MMKeys keys;
|
||||
std::vector<MMPerKey> per_key;
|
||||
|
|
|
|||
Loading…
Reference in New Issue