diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 21ac14b..bf7bd68 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -219,6 +219,181 @@ class MMStoreHorizontalSumsIntoC {
   }
 };  // MMStoreHorizontalSumsIntoC
 
+// Stateless, wraps member functions.
+class MMDecompress {
+ public:
+  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
+  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
+  // thanks to its large table lookups, and less so on other targets.
+  template <typename TB>
+  static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                                   const IndexRange& range_kc,
+                                   const StridedViewBF B_view) {
+    const hn::ScalableTag<BF16> dbf;
+    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
+
+    // Neither A nor B require padding because `LoopKC` handles remainders.
+    if constexpr (hwy::IsSame<TB, BF16>()) {
+      return View(B, row_b, range_kc.begin(), range_kc.Num());
+    }
+
+    const PackedSpan<const TB> B_span = B.PaddedSpan();
+
+    const size_t kc = range_kc.Num();
+    const size_t col0 = range_kc.begin();
+
+    for (size_t r = 0; r < kNR; ++r) {
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
+      BF16* HWY_RESTRICT to = B_view.Row(r);
+      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
+      // Verify that we zero-padded.
+      if constexpr (HWY_IS_DEBUG_BUILD) {
+        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
+          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+        }
+      }
+    }
+    return B_view;
+  }
+
+  template <typename TA>
+  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
+                                                   MMAutoTune<MMParA>& autotune,
+                                                   const MatMulEnv& env,
+                                                   MMOptions options) {
+    if constexpr (IsBF16<TA>()) {
+      // We can use a view, regardless of columns/padding, because
+      // `MMKernel::LoopKC` supports non-vector multiples.
+      return View(A, 0, 0, A.Cols());
+    } else {
+      // Always decompress. To reduce code size/compile time, we no longer
+      // support a separate F32 kernel; most A are already BF16. We also only
+      // have a single MMStorage.
+      HWY_ASSERT(options.cluster_idx == 0);
+      const StridedViewBF A_view = env.storage.A(A.Extents());
+      AutotuneDecompressA(A, A_view, autotune, env, options);
+      return A_view;
+    }
+  }
+
+ private:
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  template <typename T>
+  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
+                             size_t cols) {
+    HWY_DASSERT(c < AB.Cols());
+    HWY_DASSERT(cols <= AB.Cols() - c);
+    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
+  }
+
+  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
+  static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
+                                       const StridedViewBF A_view,
+                                       MMAutoTune<MMParA>& autotune,
+                                       MMParA par_a, const MatMulEnv& env,
+                                       const MMOptions& options) {
+    const IndexRange all_M(0, A.Rows());
+    const IndexRange all_K(0, A.Cols());
+    HWY_DASSERT(all_K.Num() == A_view.Cols());
+
+    const hn::ScalableTag<BF16> dbf;
+    const size_t NBF = hn::Lanes(dbf);
+
+    static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
+
+    const auto do_range =
+        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
+            HWY_ATTR {
+              MMZone mm_zone;
+              mm_zone.MaybeEnter(worker, zone, env, &autotune);
+
+              const size_t col0 = range_K.begin();
+              const size_t cols = range_K.Num();
+              // Must be a vector multiple, or the last range before row
+              // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
+              HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
+              for (size_t row_a : range_M) {
+                const PackedSpan<const float> from =
+                    MakeSpan(A.Row(row_a) + col0, cols);
+                BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
+                DecompressAndZeroPad(dbf, from, 0, to, cols);
+                // Verify that we zero-padded.
+                if constexpr (HWY_IS_DEBUG_BUILD) {
+                  for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
+                    HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+                  }
+                }
+              }
+            };
+
+    switch (par_a) {
+      case MMParA::kNone:
+        do_range(all_M, all_K, env.ctx.Worker(options.cluster_idx));
+        break;
+
+      case MMParA::kK1:
+      case MMParA::kK2:
+      case MMParA::kK4: {
+        const size_t inner_tasks = static_cast<size_t>(par_a);
+        // At least one vector, otherwise DecompressAndZeroPad will add
+        // padding, which might overwrite neighboring tasks. Also a whole cache
+        // line to avoid false sharing.
+        const size_t multiple_K =
+            HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
+
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
+                        options.cluster_idx,
+                        [&](const IndexRange& range_K, size_t worker) {
+                          do_range(all_M, range_K, worker);
+                        });
+        });
+        break;
+      }
+      case MMParA::kM:
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
+                              [&](size_t row_a, size_t worker) {
+                                do_range(IndexRange(row_a, row_a + 1), all_K,
+                                         worker);
+                              });
+        });
+        break;
+    }
+  }
+
+  // Autotuning wrapper for `DecompressA`.
+  static HWY_INLINE void AutotuneDecompressA(const MatPtrT<float>& A,
+                                             const StridedViewBF A_view,
+                                             MMAutoTune<MMParA>& autotune,
+                                             const MatMulEnv& env,
+                                             const MMOptions& options) {
+    if (HWY_LIKELY(autotune.Best())) {
+      return DecompressA(A, A_view, autotune, *autotune.Best(), env, options);
+    }
+
+    // First call: generate candidates.
+    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
+      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
+      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
+                                        other};
+      autotune.SetCandidates(candidates);
+    }
+
+    const MMParA& par_a = autotune.NextConfig();
+    const uint64_t t0 = hwy::timer::Start();
+    DecompressA(A, A_view, autotune, par_a, env, options);
+    const uint64_t t1 =
+        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
+    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
+    if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
+      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
+              static_cast<double>(min_elapsed) /
+                  hwy::platform::InvariantTicksPerSecond() * 1E6);
+    }
+  }
+};  // MMDecompress
+
 // Stateless, wraps member functions. Contains the innermost 2-4 loops.
 class MMKernel {
   // Compute size of per-worker storage for `kNR` row ranges of B. Stack
@@ -288,49 +463,6 @@ class MMKernel {
 
   static constexpr size_t B_storage_max = kNR * B_stride_max;
 
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
-  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
-  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
-  // thanks to its large table lookups, and less so on other targets.
-  template <typename TB>
-  static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
-                                   const IndexRange& range_kc,
-                                   const StridedViewBF B_view) {
-    const hn::ScalableTag<BF16> dbf;
-    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
-
-    // Neither A nor B require padding because `LoopKC` handles remainders.
-    if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
-    }
-
-    const PackedSpan<const TB> B_span = B.PaddedSpan();
-
-    const size_t kc = range_kc.Num();
-    const size_t col0 = range_kc.begin();
-
-    for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
-      BF16* HWY_RESTRICT to = B_view.Row(r);
-      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
-      // Verify that we zero-padded.
-      if constexpr (HWY_IS_DEBUG_BUILD) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-        }
-      }
-    }
-    return B_view;
-  }
-
   // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
   // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
   // `ForeachKC` and when there is only a single KC task.
@@ -350,7 +482,8 @@ class MMKernel {
 
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
-      StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
+      StridedViewBF B_view =
+          MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
       A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
     }
   }
@@ -742,137 +875,6 @@ class MMImpl {
       HWY_ASSERT(hwy::IsAligned(A.RowBytes(1), vector_bytes));
     }
   }
-
-  static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
-    return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
-  }
-
-  // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
-  static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
-                                         const StridedViewBF A_view,
-                                         MMAutoTune<MMParA>& autotune,
-                                         MMParA par_a, const MatMulEnv& env,
-                                         const MMOptions& options) {
-    const IndexRange all_M(0, A.Rows());
-    const IndexRange all_K(0, A.Cols());
-    HWY_DASSERT(all_K.Num() == A_view.Cols());
-
-    const hn::ScalableTag<BF16> dbf;
-    const size_t NBF = hn::Lanes(dbf);
-
-    static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
-
-    const auto do_range =
-        [&](const IndexRange& range_M, const IndexRange& range_K, size_t worker)
-            HWY_ATTR {
-              MMZone mm_zone;
-              mm_zone.MaybeEnter(worker, zone, env, &autotune);
-
-              const size_t col0 = range_K.begin();
-              const size_t cols = range_K.Num();
-              // Must be a vector multiple, or the last range before row
-              // padding, otherwise `DecompressAndZeroPad` overwrites neighbors.
-              HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
-              for (size_t row_a : range_M) {
-                const PackedSpan<const float> from =
-                    MakeSpan(A.Row(row_a) + col0, cols);
-                BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
-                DecompressAndZeroPad(dbf, from, 0, to, cols);
-                // Verify that we zero-padded.
-                if constexpr (HWY_IS_DEBUG_BUILD) {
-                  for (size_t i = cols; i < hwy::RoundUpTo(cols, NBF); ++i) {
-                    HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-                  }
-                }
-              }
-            };
-
-    switch (par_a) {
-      case MMParA::kNone:
-        do_range(all_M, all_K, Worker(env, options.cluster_idx));
-        break;
-
-      case MMParA::kK1:
-      case MMParA::kK2:
-      case MMParA::kK4: {
-        const size_t inner_tasks = static_cast<size_t>(par_a);
-        // At least one vector, otherwise DecompressAndZeroPad will add
-        // padding, which might overwrite neighboring tasks. Also a whole cache
-        // line to avoid false sharing.
-        const size_t multiple_K =
-            HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
-
-        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
-          parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
-                        options.cluster_idx,
-                        [&](const IndexRange& range_K, size_t worker) {
-                          do_range(all_M, range_K, worker);
-                        });
-        });
-        break;
-      }
-      case MMParA::kM:
-        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
-          parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
-                              [&](size_t row_a, size_t worker) {
-                                do_range(IndexRange(row_a, row_a + 1), all_K,
-                                         worker);
-                              });
-        });
-        break;
-    }
-  }
-
-  // Autotuning wrapper for `DoDecompressA`.
-  static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
-                                     const StridedViewBF A_view,
-                                     MMAutoTune<MMParA>& autotune,
-                                     const MatMulEnv& env,
-                                     const MMOptions& options) {
-    if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
-    }
-
-    // First call: generate candidates.
-    if (HWY_UNLIKELY(!autotune.HasCandidates())) {
-      const MMParA other = (A.Rows() == 1) ? MMParA::kNone : MMParA::kM;
-      std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4,
-                                        other};
-      autotune.SetCandidates(candidates);
-    }
-
-    const MMParA& par_a = autotune.NextConfig();
-    const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA(A, A_view, autotune, par_a, env, options);
-    const uint64_t t1 =
-        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
-    const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
-    if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
-      fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
-              static_cast<double>(min_elapsed) /
-                  hwy::platform::InvariantTicksPerSecond() * 1E6);
-    }
-  }
-
-  template <typename TA>
-  static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
-                                                   MMAutoTune<MMParA>& autotune,
-                                                   const MatMulEnv& env,
-                                                   MMOptions options) {
-    if constexpr (IsBF16<TA>()) {
-      // We can use a view, regardless of columns/padding, because `LoopKC`
-      // supports non-vector multiples.
-      return MMKernel::View(A, 0, 0, A.Cols());
-    } else {
-      // Always decompress. To reduce code size/compile time, we no longer
-      // support a separate F32 kernel; most A are already BF16. We also only
-      // have a single MMStorage.
-      HWY_ASSERT(options.cluster_idx == 0);
-      const StridedViewBF A_view = env.storage.A(A.Extents());
-      DecompressA(A, A_view, autotune, env, options);
-      return A_view;
-    }
-  }
 };
 
 // Defines several variants of the outer M/N/K loops (see `MMOrder`).
@@ -885,7 +887,7 @@ class MMLoops {
                        RowPtrs C_rows, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
     PROFILER_ZONE3(args.env.ctx.profiler,
-                   MMImpl::Worker(args.env, args.options.cluster_idx), zone);
+                   args.env.ctx.Worker(args.options.cluster_idx), zone);
 
     DispatchParallelism(
         args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
@@ -931,7 +933,7 @@ class MMLoops {
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
       StridedViewBF B_view =
-          MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
+          MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
       MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
                      args, C_rows);
     }
@@ -1026,8 +1028,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
   const size_t cluster_idx = options.cluster_idx;
   HWY_DASSERT(cluster_idx < env.row_ptrs.size());
-  PROFILER_ZONE3(env.ctx.profiler,
-                 cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, env.ctx.Worker(cluster_idx), zone);
 
   RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
 
@@ -1041,7 +1042,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 
   // (Also auto-tunes, hence outside the timed section to prevent interference.)
   const StridedViewBF A_view =
-      MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
+      MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
 
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
   if (HWY_LIKELY(tuner.Best())) {
diff --git a/ops/matmul.h b/ops/matmul.h
index 915970c..a85d192 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -116,7 +116,7 @@ struct MMParallelNone {
             size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
             const Func& func) const {
     HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
     func(range_n, worker);
   }
 
@@ -125,7 +125,7 @@ struct MMParallelNone {
                 const IndexRangePartition& ranges_mc,
                 const IndexRangePartition& ranges_nc, size_t cluster_idx,
                 const Func& func) const {
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
 
     for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) {
       const IndexRange range_mc = ranges_mc.Range(i);
@@ -139,7 +139,7 @@ struct MMParallelNone {
   template <class Func>
   void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc,
                   size_t cluster_idx, const Func& func) const {
-    const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t worker = ctx.Worker(cluster_idx);
     for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) {
      func(row_a, worker);
     }
@@ -154,7 +154,7 @@ struct MMParallelWithinCluster {
 
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     const IndexRangePartition worker_ranges = StaticPartition(
         range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -171,7 +171,7 @@ struct MMParallelWithinCluster {
                 const Func& func) const {
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     // Low-batch: avoid Divide/Remainder.
     if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
@@ -192,7 +192,7 @@ struct MMParallelWithinCluster {
                   size_t cluster_idx, const Func& func) const {
     const size_t pkg_idx = 0;
     hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-    const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+    const size_t base = ctx.Worker(cluster_idx);
 
     cluster.Run(
         range_mc.begin(), range_mc.end(),
@@ -233,8 +233,7 @@ struct MMParallelHierarchical {
         n_ranges, all_clusters,
         [&](const IndexRange& n_range, const size_t cluster_idx) {
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
-          const size_t cluster_base =
-              cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          const size_t cluster_base = ctx.Worker(cluster_idx);
           // Parallel-for over sub-ranges of `cluster_range` within the cluster.
           const IndexRangePartition worker_ranges = StaticPartition(
               n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
@@ -284,8 +283,7 @@ struct MMParallelHierarchical {
     ParallelizeOneRange(
         ranges_nc, all_clusters,
         [&](const IndexRange range_nc, size_t cluster_idx) {
-          const size_t cluster_base =
-              cluster_idx * ctx.pools.MaxWorkersPerCluster();
+          const size_t cluster_base = ctx.Worker(cluster_idx);
           hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
           ParallelizeOneRange(ranges_mc, cluster,
                               [&](const IndexRange& range_mc, size_t worker) {
diff --git a/util/threading_context.h b/util/threading_context.h
index 41d0811..ac42526 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -97,6 +97,13 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 struct ThreadingContext {
   explicit ThreadingContext(const ThreadingArgs& args);
 
+  // Returns a worker index compatible with those from `ParallelFor`, assuming
+  // a single thread is running per cluster, as is the case when
+  // `ParallelismStrategy` is `kAcrossClusters`.
+  size_t Worker(size_t cluster_idx) const {
+    return cluster_idx * pools.MaxWorkersPerCluster();
+  }
+
   // Singleton; pass around a reference to reduce overhead.
   hwy::Profiler& profiler;
 
@@ -158,7 +165,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
 
   switch (parallelism) {
     case ParallelismStrategy::kNone: {
-      const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      const size_t worker = ctx.Worker(cluster_idx);
       for (size_t task = 0; task < num_tasks; ++task) {
         func(task, worker);
       }
@@ -173,7 +180,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
     case ParallelismStrategy::kWithinCluster: {
       // Ensure the worker argument is unique across clusters, because it is
      // used for TLS indexing for example in profiler.h.
-      const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
+      const size_t base = ctx.Worker(cluster_idx);
       return ctx.pools.Cluster(pkg_idx, cluster_idx)
           .Run(0, num_tasks, [&](uint64_t task, size_t worker) {
             func(task, base + worker);
@@ -193,8 +200,7 @@ void ParallelFor(ParallelismStrategy parallelism, size_t num_tasks,
 
       return ctx.pools.AllClusters(pkg_idx).Run(
           0, num_tasks, [&](uint64_t task, size_t cluster_idx) {
-            const size_t worker =
-                cluster_idx * ctx.pools.MaxWorkersPerCluster();
+            const size_t worker = ctx.Worker(cluster_idx);
             func(task, worker);
           });
     }
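For reviewers, a minimal standalone sketch (not part of the patch; pool sizes are
illustrative assumptions) of the indexing scheme that the new
`ThreadingContext::Worker` centralizes: each cluster receives a disjoint block of
worker indices, so per-worker state such as profiler TLS slots cannot collide
across clusters.

#include <cstddef>
#include <cstdio>

int main() {
  // Assumed stand-ins for ctx.pools.MaxWorkersPerCluster() and the cluster
  // count; the real values come from ThreadingContext.
  const size_t max_workers_per_cluster = 8;
  const size_t num_clusters = 2;
  for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
    // Same arithmetic as ctx.Worker(cluster_idx): the base index of a cluster.
    const size_t base = cluster_idx * max_workers_per_cluster;
    for (size_t thread = 0; thread < max_workers_per_cluster; ++thread) {
      // Within-cluster parallel loops pass `base + worker` to their callbacks.
      std::printf("cluster %zu thread %zu -> worker %zu\n", cluster_idx, thread,
                  base + thread);
    }
  }
  return 0;
}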