Refactor: move Worker to ThreadingContext, factor out MMDecompress

PiperOrigin-RevId: 804909921
Jan Wassenberg 2025-09-09 07:55:39 -07:00 committed by Copybara-Service
parent 461a9c7d1b
commit 24b1760f03
3 changed files with 199 additions and 194 deletions
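
For orientation, a minimal self-contained sketch of what the `Worker` move amounts to. The stand-in types and the worker count are hypothetical; the real `ThreadingContext` and its `pools` member appear in the third file below.

    #include <cstddef>

    // Stand-in for the pools member; MaxWorkersPerCluster() exists on the real
    // object, but the value used here is hypothetical.
    struct PoolsSketch {
      size_t MaxWorkersPerCluster() const { return 8; }
    };

    // Stand-in for ThreadingContext after this commit: the helper centralizes
    // the index computation that call sites previously wrote by hand.
    struct ThreadingContextSketch {
      PoolsSketch pools;
      size_t Worker(size_t cluster_idx) const {
        return cluster_idx * pools.MaxWorkersPerCluster();
      }
    };

    int main() {
      ThreadingContextSketch ctx;
      // Old call-site spelling, as removed throughout this diff:
      const size_t before = 2 * ctx.pools.MaxWorkersPerCluster();
      // New call-site spelling:
      const size_t after = ctx.Worker(2);
      return before == after ? 0 : 1;  // 0: both compute the same index
    }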


@@ -219,6 +219,181 @@ class MMStoreHorizontalSumsIntoC {
}
}; // MMStoreHorizontalSumsIntoC
// Stateless, wraps member functions.
class MMDecompress {
public:
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
// thanks to its large table lookups, and less so on other targets.
template <typename TB>
static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const StridedViewBF B_view) {
const hn::ScalableTag<BF16> dbf;
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
// Neither A nor B requires padding because `LoopKC` handles remainders.
if constexpr (hwy::IsSame<TB, BF16>()) {
return View(B, row_b, range_kc.begin(), range_kc.Num());
}
const PackedSpan<const TB> B_span = B.PaddedSpan();
const size_t kc = range_kc.Num();
const size_t col0 = range_kc.begin();
for (size_t r = 0; r < kNR; ++r) {
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
BF16* HWY_RESTRICT to = B_view.Row(r);
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
}
}
}
return B_view;
}
template <typename TA>
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
MMAutoTune<MMParA>& autotune,
const MatMulEnv& env,
MMOptions options) {
if constexpr (IsBF16<TA>()) {
// We can use a view, regardless of columns/padding, because
// `MMKernel::LoopKC` supports non-vector multiples.
return View(A, 0, 0, A.Cols());
} else {
// Always decompress. To reduce code size/compile time, we no longer
// support a separate F32 kernel; most A are already BF16. We also only
// have a single MMStorage.
HWY_ASSERT(options.cluster_idx == 0);
const StridedViewBF A_view = env.storage.A(A.Extents());
AutotuneDecompressA(A, A_view, autotune, env, options);
return A_view;
}
}
private:
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMAutoTune<MMParA>& autotune,
MMParA par_a, const MatMulEnv& env,
const MMOptions& options) {
const IndexRange all_M(0, A.Rows());
const IndexRange all_K(0, A.Cols());
HWY_DASSERT(all_K.Num() == A_view.Cols());
const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf);
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, env, &autotune);
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
// Must be a vector multiple, or the last range before row
// padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
for (size_t row_a : range_M) {
const PackedSpan<const float> from =
MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
DecompressAndZeroPad(dbf, from, 0, to, cols);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
}
}
}
};
switch (par_a) {
case MMParA::kNone:
do_range(all_M, all_K, env.ctx.Worker(options.cluster_idx));
break;
case MMParA::kK1:
case MMParA::kK2:
case MMParA::kK4: {
const size_t inner_tasks = static_cast<size_t>(par_a);
// At least one vector, otherwise DecompressAndZeroPad will add
// padding, which might overwrite neighboring tasks. Also a whole cache
// line to avoid false sharing.
const size_t multiple_K =
HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
options.cluster_idx,
[&](const IndexRange& range_K, size_t worker) {
do_range(all_M, range_K, worker);
});
});
break;
}
case MMParA::kM:
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
[&](size_t row_a, size_t worker) {
do_range(IndexRange(row_a, row_a + 1), all_K,
worker);
});
});
break;
}
}
// Autotuning wrapper for `DecompressA`.
static HWY_INLINE void AutotuneDecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMAutoTune<MMParA>& autotune,
const MatMulEnv& env,
const MMOptions& options) {
if (HWY_LIKELY(autotune.Best())) {
return DecompressA(A, A_view, autotune, *autotune.Best(), env, options);
}
// First call: generate candidates.
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
other};
autotune.SetCandidates(candidates);
}
const MMParA& par_a = autotune.NextConfig();
const uint64_t t0 = hwy::timer::Start();
DecompressA(A, A_view, autotune, par_a, env, options);
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
static_cast<double>(min_elapsed) /
hwy::platform::InvariantTicksPerSecond() * 1E6);
}
}
}; // MMDecompress
// Stateless, wraps member functions. Contains the innermost 2-4 loops.
class MMKernel {
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
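
As an aside, a hedged scalar sketch of the contract that the `HWY_IS_DEBUG_BUILD` loops above verify for `DecompressAndZeroPad`. This is not the vectorized gemma.cpp implementation; the real code converts to BF16 and takes the lane count from `hn::Lanes`.

    #include <cassert>
    #include <cstddef>

    // The first `num` outputs hold the decompressed inputs; the tail up to the
    // next multiple of the vector lane count `lanes` is zeroed, so kernels may
    // safely read whole vectors past the logical end of the row.
    void DecompressAndZeroPadScalar(const float* from, float* to, size_t num,
                                    size_t lanes) {
      for (size_t i = 0; i < num; ++i) to[i] = from[i];         // "decompress"
      const size_t padded = (num + lanes - 1) / lanes * lanes;  // RoundUpTo
      for (size_t i = num; i < padded; ++i) to[i] = 0.0f;       // zero padding
    }

    int main() {
      const float from[5] = {1, 2, 3, 4, 5};
      float to[8];
      DecompressAndZeroPadScalar(from, to, 5, 8);  // lanes = 8 (hypothetical)
      for (size_t i = 5; i < 8; ++i) assert(to[i] == 0.0f);  // as in the checks
      return 0;
    }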
@@ -288,49 +463,6 @@ class MMKernel {
static constexpr size_t B_storage_max = kNR * B_stride_max;
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
// col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
// thanks to its large table lookups, and less so on other targets.
template <typename TB>
static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const StridedViewBF B_view) {
const hn::ScalableTag<BF16> dbf;
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
// Neither A nor B require padding because `LoopKC` handles remainders.
if constexpr (hwy::IsSame<TB, BF16>()) {
return View(B, row_b, range_kc.begin(), range_kc.Num());
}
const PackedSpan<const TB> B_span = B.PaddedSpan();
const size_t kc = range_kc.Num();
const size_t col0 = range_kc.begin();
for (size_t r = 0; r < kNR; ++r) {
const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
BF16* HWY_RESTRICT to = B_view.Row(r);
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
}
}
}
return B_view;
}
// Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
// `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
// `ForeachKC` and when there is only a single KC task.
@@ -350,7 +482,8 @@ class MMKernel {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
StridedViewBF B_view =
MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
}
}
@@ -742,137 +875,6 @@ class MMImpl {
HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes));
}
}
static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
}
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMAutoTune<MMParA>& autotune,
MMParA par_a, const MatMulEnv& env,
const MMOptions& options) {
const IndexRange all_M(0, A.Rows());
const IndexRange all_K(0, A.Cols());
HWY_DASSERT(all_K.Num() == A_view.Cols());
const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf);
static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
const auto do_range =
[&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
HWY_ATTR {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, env, &autotune);
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
// Must be a vector multiple, or the last range before row
// padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
for (size_t row_a : range_M) {
const PackedSpan<const float> from =
MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
DecompressAndZeroPad(dbf, from, 0, to, cols);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
}
}
}
};
switch (par_a) {
case MMParA::kNone:
do_range(all_M, all_K, Worker(env, options.cluster_idx));
break;
case MMParA::kK1:
case MMParA::kK2:
case MMParA::kK4: {
const size_t inner_tasks = static_cast<size_t>(par_a);
// At least one vector, otherwise DecompressAndZeroPad will add
// padding, which might overwrite neighboring tasks. Also a whole cache
// line to avoid false sharing.
const size_t multiple_K =
HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
options.cluster_idx,
[&](const IndexRange& range_K, size_t worker) {
do_range(all_M, range_K, worker);
});
});
break;
}
case MMParA::kM:
DispatchParallelism(options.parallelism, [&](const auto& parallel) {
parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
[&](size_t row_a, size_t worker) {
do_range(IndexRange(row_a, row_a + 1), all_K,
worker);
});
});
break;
}
}
// Autotuning wrapper for `DoDecompressA`.
static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMAutoTune<MMParA>& autotune,
const MatMulEnv& env,
const MMOptions& options) {
if (HWY_LIKELY(autotune.Best())) {
return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
}
// First call: generate candidates.
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
other};
autotune.SetCandidates(candidates);
}
const MMParA& par_a = autotune.NextConfig();
const uint64_t t0 = hwy::timer::Start();
DoDecompressA(A, A_view, autotune, par_a, env, options);
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
static_cast<double>(min_elapsed) /
hwy::platform::InvariantTicksPerSecond() * 1E6);
}
}
template <typename TA>
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
MMAutoTune<MMParA>& autotune,
const MatMulEnv& env,
MMOptions options) {
if constexpr (IsBF16<TA>()) {
// We can use a view, regardless of columns/padding, because `LoopKC`
// supports non-vector multiples.
return MMKernel::View(A, 0, 0, A.Cols());
} else {
// Always decompress. To reduce code size/compile time, we no longer
// support a separate F32 kernel; most A are already BF16. We also only
// have a single MMStorage.
HWY_ASSERT(options.cluster_idx == 0);
const StridedViewBF A_view = env.storage.A(A.Extents());
DecompressA(A, A_view, autotune, env, options);
return A_view;
}
}
};
// Defines several variants of the outer M/N/K loops (see `MMOrder`).
@@ -885,7 +887,7 @@ class MMLoops {
RowPtrs<TC> C_rows, const MMArgs& args) {
static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
PROFILER_ZONE3(args.env.ctx.profiler,
MMImpl::Worker(args.env, args.options.cluster_idx), zone);
args.env.ctx.Worker(args.options.cluster_idx), zone);
DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@@ -931,7 +933,7 @@ class MMLoops {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view =
MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
args, C_rows);
}
@@ -1026,8 +1028,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
const size_t cluster_idx = options.cluster_idx;
HWY_DASSERT(cluster_idx < env.row_ptrs.size());
PROFILER_ZONE3(env.ctx.profiler,
cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
@@ -1041,7 +1042,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// (Also auto-tunes, hence outside the timed section to prevent interference.)
const StridedViewBF A_view =
MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
MMAutoTune<MMConfig>& tuner = per_key.autotune;
if (HWY_LIKELY(tuner.Best())) {


@@ -116,7 +116,7 @@ struct MMParallelNone {
size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t worker = ctx.Worker(cluster_idx);
func(range_n, worker);
}
@@ -125,7 +125,7 @@ struct MMParallelNone {
const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t cluster_idx,
const Func& func) const {
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t worker = ctx.Worker(cluster_idx);
for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
const IndexRange range_mc = ranges_mc.Range(i);
@@ -139,7 +139,7 @@ struct MMParallelNone {
template <class Func>
void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
size_t cluster_idx, const Func& func) const {
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t worker = ctx.Worker(cluster_idx);
for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
func(row_a, worker);
}
@@ -154,7 +154,7 @@ struct MMParallelWithinCluster {
const size_t pkg_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t base = ctx.Worker(cluster_idx);
const IndexRangePartition worker_ranges = StaticPartition(
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -171,7 +171,7 @@ struct MMParallelWithinCluster {
const Func& func) const {
const size_t pkg_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t base = ctx.Worker(cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
@@ -192,7 +192,7 @@ struct MMParallelWithinCluster {
size_t cluster_idx, const Func& func) const {
const size_t pkg_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t base = ctx.Worker(cluster_idx);
cluster.Run(
range_mc.begin(), range_mc.end(),
@@ -233,8 +233,7 @@ struct MMParallelHierarchical {
n_ranges, all_clusters,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t cluster_base = ctx.Worker(cluster_idx);
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -284,8 +283,7 @@ struct MMParallelHierarchical {
ParallelizeOneRange(
ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) {
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t cluster_base = ctx.Worker(cluster_idx);
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t worker) {


@@ -97,6 +97,13 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
struct ThreadingContext {
explicit ThreadingContext(const ThreadingArgs& args);
// Returns a worker index compatible with those from `ParallelFor`, assuming
// one worker thread per cluster, as is the case when `ParallelismStrategy`
// is `kAcrossClusters`.
size_t Worker(size_t cluster_idx) const {
return cluster_idx * pools.MaxWorkersPerCluster();
}
// Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler;
@@ -158,7 +165,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
switch (parallelism) {
case ParallelismStrategy::kNone: {
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t worker = ctx.Worker(cluster_idx);
for (size_t task = 0; task < num_tasks; ++task) {
func(task, worker);
}
@@ -173,7 +180,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
case ParallelismStrategy::kWithinCluster: {
// Ensure the worker argument is unique across clusters, because it is
// used for TLS indexing, for example in profiler.h.
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t base = ctx.Worker(cluster_idx);
return ctx.pools.Cluster(pkg_idx, cluster_idx)
.Run(0, num_tasks, [&](uint64_t task, size_t worker) {
func(task, base + worker);
@@ -193,8 +200,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
return ctx.pools.AllClusters(pkg_idx).Run(
0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
const size_t worker =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
const size_t worker = ctx.Worker(cluster_idx);
func(task, worker);
});
}
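
Finally, a hedged sketch (with hypothetical worker counts) of why adding the cluster base keeps worker indices unique across clusters, as the TLS-indexing comment above requires: cluster `c` owns the contiguous block `[c*W, (c+1)*W)`, where `W = MaxWorkersPerCluster()`, and a cluster never runs more than `W` workers, so the blocks cannot overlap.

    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t kW = 8;  // MaxWorkersPerCluster(), hypothetical value
      for (size_t cluster_idx = 0; cluster_idx < 2; ++cluster_idx) {
        // What ctx.Worker(cluster_idx) returns; the start of this cluster's block.
        const size_t base = cluster_idx * kW;
        for (size_t worker = 0; worker < 3; ++worker) {  // local pool index
          std::printf("cluster %zu, local worker %zu -> global worker %zu\n",
                      cluster_idx, worker, base + worker);
        }
      }
      return 0;
    }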