mirror of https://github.com/google/gemma.cpp.git
Refactor Matmul to use a policy class for parallelization.
PiperOrigin-RevId: 800864489
This commit is contained in:
parent 6c39a2dea4
commit 973e284ed6
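Context for the change: the stateful MMParallel member is replaced by a policy class (MMNestedParallelPolicy) whose static members receive the ThreadingContext explicitly, so matmul code can later be instantiated with other parallelization strategies. A minimal sketch of the pattern follows; names, signatures, and the sequential stand-in policy are simplified illustrations, not the exact gemma.cpp API.

#include <cstddef>

struct ThreadingContext;  // assumed: owns thread pools, allocator, topology

// A parallelization policy exposes static entry points that take the
// ThreadingContext as an argument instead of capturing it in a member object.
struct SequentialPolicy {
  template <class Func>
  static void ForPkg(ThreadingContext& /*ctx*/, size_t max_packages,
                     const Func& func) {
    // Stand-in for the real per-package thread pool: run tasks in order.
    for (size_t pkg_idx = 0; pkg_idx < max_packages; ++pkg_idx) {
      func(pkg_idx);
    }
  }
};

// Callers are templated on the policy (or take a tag object), so the same
// matmul code can run with nested, per-cluster, or sequential parallelism
// without holding a stateful MMParallel instance.
template <class ParallelPolicy>
void RunPerPackage(ThreadingContext& ctx, size_t num_packages) {
  ParallelPolicy::ForPkg(ctx, num_packages, [&](size_t pkg_idx) {
    // ... per-package matmul work, as in MMPerPackage::operator() ...
    (void)pkg_idx;
  });
}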
@@ -320,8 +320,6 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
 
-  MMParallel parallel(ctx);
-
   // Allocate in parallel because faulting in large tensors is slow.
   ctx.pools.Pool().Run(
       0, tensors.size(), [&](uint64_t task, size_t /*thread*/) {
@@ -339,7 +337,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
 
         owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
                                          tensor.padding);
-        BindB(*tensor.mat, tensor.mat->ElementBytes(), parallel);
+        BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
       });
 }
 
108 ops/matmul-inl.h
@@ -779,17 +779,19 @@ class MMPerPackage {
   // B and maybe A are decompressed several call layers lower, but not all
   // member functions depend on TA/TB, so pass them as an argument instead of
   // templating the class.
-  template <typename TA, typename TB, typename TC>
-  HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
+                               const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                RowPtrs<TC> C_rows) const {
     if constexpr (WantDecompressA<TA>()) {
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
-      DecompressA(A, A_view);
+      DecompressA<MMParallelPolicyT>(A, A_view);
       constexpr bool A_padded = true;  // MMStorage `pkg_A_` is padded.
-      DispatchOrder(A_view, A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
     } else {
       const bool A_padded = HasPadding(A);
-      DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
+                    C_rows);
     }
   }
 
@@ -828,28 +830,30 @@ class MMPerPackage {
   }
 
   // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
-  template <typename TA, typename TB, typename TC>
-  HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded,
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
+                                const StridedView<TA> A, const bool A_padded,
                                 const MatPtrT<TB>& B,
                                 RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(A, A_padded, B, C_rows);
+        return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(A, A_padded, B, C_rows);
+        return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(A, A_padded, B, C_rows);
+        return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(A, A_padded, B, C_rows);
+        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
       default:
        HWY_UNREACHABLE;
    }
  }
 
   // Single M and K ranges, parallel N. Fills all of C directly.
-  template <typename TA, typename TB, typename TC>
-  HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded,
-                       const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
+                       const bool A_padded, const MatPtrT<TB>& B,
+                       RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -861,9 +865,9 @@ class MMPerPackage {
         Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
-    args_.env->parallel.ForNP(
-        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+    MMParallelPolicyT::ForNP(
+        args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
+        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
 
@@ -881,9 +885,10 @@ class MMPerPackage {
   }
 
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
-  template <typename TA, typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
-                         const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
+                         const bool A_padded, const MatPtrT<TB>& B,
+                         RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,9 +914,9 @@ class MMPerPackage {
       }
     };
 
-    args_.env->parallel.ForNP(
-        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
-        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+    MMParallelPolicyT::ForNP(
+        args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
+        pkg_idx_, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args_);
 
@@ -930,9 +935,10 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template <typename TA, typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded,
-                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
+                          const bool A_padded, const MatPtrT<TB>& B,
+                          RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
@@ -942,8 +948,8 @@ class MMPerPackage {
 
     // Sequential loop over NC/MC/KC, similar to `loop_nc` below
     // except for the profiler strings and `out_tag`.
-    args_.env->parallel.ForRangesMC_NC(
-        ranges_mc_, ranges_nc_, pkg_idx_,
+    MMParallelPolicyT::ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
@@ -965,9 +971,10 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
-  template <typename TA, typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
-                            const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
+  template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
+  HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
+                            const bool A_padded, const MatPtrT<TB>& B,
+                            RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -992,8 +999,8 @@ class MMPerPackage {
                  out_tag, args_, C_rows);
       }
     };  // loop_nc
-    args_.env->parallel.ForRangesMC_NC(
-        ranges_mc_, ranges_nc_, pkg_idx_,
+    MMParallelPolicyT::ForRangesMC_NC(
+        args_.env->ctx, ranges_mc_, ranges_nc_, pkg_idx_,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
            size_t worker) HWY_ATTR {
          MMZone mm_zone;
@@ -1014,6 +1021,7 @@ class MMPerPackage {
   }
 
   // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  template <typename MMParallelPolicyT>
   HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
                                   const StridedViewBF A_view,
                                   MMParA par_a) const {
@@ -1064,16 +1072,16 @@ class MMPerPackage {
         // line to avoid false sharing.
         const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
 
-        args_.env->parallel.ForNP(
-            all_K, multiple_K, inner_tasks, pkg_idx_,
+        MMParallelPolicyT::ForNP(args_.env->ctx, all_K, multiple_K, inner_tasks,
+                                 pkg_idx_,
             [&](const IndexRange& range_K, size_t worker) {
               do_range(all_M, range_K, worker);
             });
         break;
       }
       case MMParA::kM:
-        args_.env->parallel.ForRangeMC(
-            all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
+        MMParallelPolicyT::ForRangeMC(
+            args_.env->ctx, all_M, pkg_idx_, [&](size_t row_a, size_t worker) {
              do_range(IndexRange(row_a, row_a + 1), all_K, worker);
            });
        break;
@@ -1081,12 +1089,13 @@ class MMPerPackage {
   }
 
   // Autotuning wrapper for `DoDecompressA`.
+  template <typename MMParallelPolicyT>
   HWY_INLINE void DecompressA(const MatPtrT<float>& A,
                               const StridedViewBF A_view) const {
     MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
 
     if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA(A, A_view, *autotune.Best());
+      return DoDecompressA<MMParallelPolicyT>(A, A_view, *autotune.Best());
     }
 
     // First call: generate candidates.
@@ -1099,7 +1108,7 @@ class MMPerPackage {
 
     const MMParA& par_a = autotune.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA(A, A_view, par_a);
+    DoDecompressA<MMParallelPolicyT>(A, A_view, par_a);
     const uint64_t t1 =
         args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
@@ -1185,19 +1194,21 @@ struct MMImpl {
 
     if constexpr (kMaxPackages > 1) {
       // Outermost loop: static NUMA-aware partition of B rows across packages.
-      args.env->parallel.ForPkg(
-          args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
+      MMNestedParallelPolicy::ForPkg(
+          args.env->ctx, args.per_key->ranges_np.NumTasks(),
+          [&](size_t pkg_idx) {
             MMZone mm_zone;
             mm_zone.MaybeEnter(pkg_idx, zone, args);
             const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-            MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
-                                                                       C_rows);
+            MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
+                MMNestedParallelPolicy(), A, B, C_rows);
           });
     } else {
       const size_t pkg_idx = 0;
       HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
       const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-      MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
+      MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(
+          MMNestedParallelPolicy(), A, B, C_rows);
     }
   }
 };
@@ -1250,8 +1261,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
     // invalidates `MMAutoTune::Best()`
     index = env.per_key.size();
-    env.per_key.push_back(
-        MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
+    env.per_key.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
   }
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
@@ -397,34 +397,33 @@ static size_t NPMultiple(const Allocator& allocator, size_t N,
   return np_multiple;
 }
 
-IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
-                                           size_t sizeof_TC, size_t nr) const {
-  const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
+IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
+                                 size_t N, size_t sizeof_TC, size_t nr) {
+  const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
   return StaticPartition(
       IndexRange(0, N), num_packages,
-      NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
+      NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
 }
 
-MatMulEnv::MatMulEnv(ThreadingContext& ctx)
-    : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
+MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
 
   row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));  // C
 }
 
-void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
-  Allocator& allocator = parallel.allocator();
+void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
+  Allocator& allocator = ctx.allocator;
   if (!allocator.ShouldBind()) return;
   if (B.Rows() == 1) return;
 
   PROFILER_ZONE("Startup.BindB");
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(kMaxPackages, B.Rows(), sizeof_TC, kNR);
+      MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
-    const size_t node = parallel.Node(pkg_idx);
+    const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
     uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
     uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
     // B row padding is less than the page size, so only bind the subset that
@@ -438,14 +437,14 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel) {
 }
 
 // C is BF16/float
-void BindC(MatPtr& C, MMParallel& parallel) {
-  Allocator& allocator = parallel.allocator();
+void BindC(ThreadingContext& ctx, MatPtr& C) {
+  Allocator& allocator = ctx.allocator;
   if (!allocator.ShouldBind()) return;
 
   PROFILER_ZONE("Startup.BindC");
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
+      MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
   bool ok = true;
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@@ -455,7 +454,7 @@ void BindC(MatPtr& C, MMParallel& parallel) {
     const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
                                         allocator.BasePageBytes());
 
-    const size_t node = parallel.Node(pkg_idx);
+    const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
     for (size_t im = 0; im < C.Rows(); ++im) {
       ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
     }
107 ops/matmul.h
@@ -53,35 +53,31 @@ constexpr size_t kNR = 4;
 // or less on ISAs with fewer registers, or for the last few rows of A.
 static constexpr size_t kMaxMR = 4;
 
-// Mostly stateless, can be constructed on the fly by weights.cc. Captures the
-// the ThreadingContext to shorten call sites.
-class MMParallel {
- public:
-  // `ctx` must outlive this object.
-  MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
-    if (ctx_.pools.NumPackages() > kMaxPackages) {
-      HWY_WARN("CPU and --max_packages allow %zu > matmul.h kMaxPackages %zu.",
-               ctx_.pools.NumPackages(), kMaxPackages);
-    }
-  }
+IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
+                                 size_t N, size_t sizeof_TC, size_t nr);
 
-  Allocator& allocator() const { return ctx_.allocator; }
+enum class ParallelismType : uint8_t {
+  kNone,
+  // No parallelism.
+  kSequential,
+  // Parallelism at cluster level.
+  kCluster,
+  // Parallelism at package level.
+  kNested,
+};
 
-  // Initial static partitioning of B rows across packages.
-  IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
-                                 size_t sizeof_TC, size_t nr) const;
+struct MMOptions {
+  ParallelismType parallelism_type_ = ParallelismType::kNested;
+  uint8_t cluster_idx_ = 0;
+};
 
-  // For `BindB` and `BindC`.
-  size_t Node(size_t pkg_idx) const {
-    return ctx_.topology.GetCluster(pkg_idx, 0).Node();
-  }
-
-  // Calls `func(pkg_idx)` for each package in parallel.
+struct MMNestedParallelPolicy {
   template <class Func>
-  void ForPkg(const size_t max_packages, const Func& func) {
+  static void ForPkg(ThreadingContext& ctx, const size_t max_packages,
+                     const Func& func) {
     if constexpr (kMaxPackages > 1) {
-      ctx_.pools.AllPackages().Run(
-          0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
+      ctx.pools.AllPackages().Run(
+          0, HWY_MIN(max_packages, ctx.pools.NumPackages()),
           [&](uint64_t task, size_t pkg_idx) {
             HWY_DASSERT(task == pkg_idx);
             (void)task;
@@ -95,16 +91,17 @@ class MMParallel {
   // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
   // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
   template <class Func>
-  void ForNP(const IndexRange& range_np, size_t nx_multiple, size_t inner_tasks,
-             size_t pkg_idx, const Func& func) {
+  static void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
+                    size_t nx_multiple, size_t inner_tasks, size_t pkg_idx,
+                    const Func& func) {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
 
     // Single cluster: parallel-for over static partition of `range_np`.
-    hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     const size_t num_clusters = all_clusters.NumWorkers();
     if (num_clusters == 1) {
-      hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, 0);
       const IndexRangePartition worker_ranges = StaticPartition(
           range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
       return ParallelizeOneRange(
@@ -120,9 +117,9 @@ class MMParallel {
     ParallelizeOneRange(
         nx_ranges, all_clusters,
         [&](const IndexRange& nx_range, const size_t cluster_idx) {
-          hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
+              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
@@ -137,18 +134,19 @@ class MMParallel {
   // Cluster/CCX-aware parallel-for over blocks (separate subranges of A and B
   // rows). Calls `func(range_mc, range_nc, worker)`.
   template <class Func>
-  void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
-                      const IndexRangePartition& ranges_nc, size_t pkg_idx,
-                      const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
-    hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
+  static void ForRangesMC_NC(ThreadingContext& ctx,
+                             const IndexRangePartition& ranges_mc,
+                             const IndexRangePartition& ranges_nc,
+                             size_t pkg_idx, const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
     // `all_clusters` is a pool with one worker per cluster in a package.
     const size_t num_clusters = all_clusters.NumWorkers();
     // Single (big) cluster: collapse two range indices into one parallel-for
     // to reduce the number of fork-joins.
     if (num_clusters == 1) {
       const size_t cluster_idx = 0;
-      hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+      hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
       // Low-batch: avoid Divide/Remainder.
       if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
         return ParallelizeOneRange(
@@ -171,8 +169,8 @@ class MMParallel {
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
           const size_t cluster_base =
-              pkg_base + cluster_idx * ctx_.pools.MaxWorkersPerCluster();
-          hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
+              pkg_base + cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           ParallelizeOneRange(ranges_mc, cluster,
                               [&](const IndexRange& range_mc, size_t thread) {
                                 func(range_mc, range_nc, cluster_base + thread);
@@ -182,21 +180,18 @@ class MMParallel {
 
   // Calls `func(row_a, worker)` in parallel.
   template <class Func>
-  void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
-                  const Func& func) {
-    const size_t pkg_base = pkg_idx * ctx_.pools.MaxWorkersPerPackage();
-    ctx_.pools.Pool(pkg_idx).Run(
+  static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
+                         size_t pkg_idx, const Func& func) {
+    const size_t pkg_base = pkg_idx * ctx.pools.MaxWorkersPerPackage();
+    ctx.pools.Pool(pkg_idx).Run(
         range_mc.begin(), range_mc.end(),
         [&](uint64_t row_a, size_t thread) { func(row_a, pkg_base + thread); });
   }
-
- private:
-  ThreadingContext& ctx_;
 };
 
-void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel);
+void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 // C is BF16/float.
-void BindC(MatPtr& C, MMParallel& parallel);
+void BindC(ThreadingContext& ctx, MatPtr& C);
 
 // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
 #pragma pack(push, 1)  // power of two size
@@ -250,15 +245,18 @@ class MMStorage {
 
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
-  MMStorage(const Allocator& allocator, MMParallel& parallel) {
+  MMStorage(ThreadingContext& ctx) {
     // Per-package allocation so each can decompress A into its own copy.
     // Must be padded, see `DoDecompressA`.
-    parallel.ForPkg(kMaxPackages, [&](size_t pkg_idx) {
+    // Default to nested parallel policy.
+    MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
+      Allocator& allocator = ctx.allocator;
+
       pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
           "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
 
       if (allocator.ShouldBind()) {
-        const size_t node = parallel.Node(pkg_idx);
+        const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
         size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
                        pkg_A_[pkg_idx]->ElementBytes();
         bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
@@ -607,9 +605,9 @@ class MMKeys {
 
 // Per-MatMul-shape state.
 struct MMPerKey {
-  MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
-           MMParallel& parallel)
-      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {
+  MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
+           size_t sizeof_TC, size_t nr)
+      : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
     HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
   }
 
@@ -639,7 +637,6 @@ struct MatMulEnv {
   // Whether to print the best config immediately after autotuning finished.
   bool print_best = false;
 
-  MMParallel parallel;
   MMStorage storage;
   MMKeys keys;
   std::vector<MMPerKey> per_key;
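For downstream callers, the header changes amount to passing the ThreadingContext instead of an MMParallel object. A hedged sketch of the call-site migration follows; the surrounding function, include path, and variable names are illustrative and not part of this commit.

#include "ops/matmul.h"  // assumed include; declares BindB/BindC/MMRangesOfNP per this diff

namespace gcpp {  // namespace assumed

// Illustrative helper: binds a weight matrix B and an output C to NUMA nodes.
void BindForMatMulExample(ThreadingContext& ctx, MatPtr& B, MatPtr& C) {
  // Old API (removed here): BindB(B, C.ElementBytes(), env.parallel);
  // New API: the ThreadingContext is passed directly.
  BindB(ctx, B, C.ElementBytes());
  BindC(ctx, C);

  // The static NUMA-aware partition of B rows is now a free function rather
  // than MMParallel::RangesOfNP.
  const IndexRangePartition ranges_np =
      MMRangesOfNP(ctx, kMaxPackages, B.Rows(), C.ElementBytes(), kNR);
  (void)ranges_np;
}

}  // namespace gcpp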