From 31c09cca4cdc0aaea8acc98175bb7b6b5154a852 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 28 Aug 2025 08:55:15 -0700 Subject: [PATCH] f32 LoopKC: 1.37x(M=512), 1.19(M=128) single-K F32,BF16 matmul speedup on SKX Add a special case for A=F32,B=BF16, used when there is no native bf16 dot product. dot-inl: ensure bf16,f32 and f32,bf16 both get promoted to float before f64 summation matmul.cc: update autotuning to reflect actual A size matmul_test: add all combinations of bf16/f32, report all results, not just first difference, check non-vector-aligned K PiperOrigin-RevId: 800487817 --- compression/test_util-inl.h | 17 +- ops/dot-inl.h | 61 +++- ops/matmul-inl.h | 633 ++++++++++++++++++++++++++---------- ops/matmul.cc | 24 +- ops/matmul.h | 8 +- ops/matmul_test.cc | 51 ++- 6 files changed, 594 insertions(+), 200 deletions(-) diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 7c4f854..207b225 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -85,15 +85,21 @@ MatStorageT GenerateMat(const Extents2D& extents, row[c] = f; } Compress(raw.Row(r), raw.Cols(), ws.tls[thread], - MakeSpan(compressed.Row(r), compressed.Cols()), + MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); + + // MatMul requires that A's padding be zero-initialized. + hwy::ZeroBytes( + compressed.Row(r) + extents.cols, + (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); compressed.SetScale(0.6f); // Arbitrary value, different from 1. return compressed; } -// Same, but `extents` describes the transposed matrix. +// Same, but `extents` describes the transposed matrix and the computation of +// `f` swaps `r` and `c`. template MatStorageT GenerateTransposedMat(const Extents2D extents, const Allocator& allocator, @@ -112,8 +118,13 @@ MatStorageT GenerateTransposedMat(const Extents2D extents, row[c] = f; } Compress(raw.Row(r), raw.Cols(), ws.tls[thread], - MakeSpan(compressed.Row(r), compressed.Cols()), + MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0); + + // MatMul requires that B's padding be zero-initialized. + hwy::ZeroBytes( + compressed.Row(r) + extents.cols, + (compressed.Stride() - extents.cols) * compressed.ElementBytes()); }); // Arbitrary value, different from 1, must match `GenerateMat`. diff --git a/ops/dot-inl.h b/ops/dot-inl.h index 48aaae9..dae2106 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -157,15 +157,16 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) { // promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA. struct DotKernelDouble { // Only `CompressTraits` can `Decompress2` to `double`, so both have - // to be `float` in order to have `Raw = double`. Note that if either type is - // smaller than `float`, we may demote the other type from `float` to `BF16`. + // to be `float` in order to have `Raw = double`. To avoid loss of accuracy, + // if either is float, we decompress both to float, otherwise `BF16`. 
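// [Sketch] The selection rule described above, written out with assumed
// template parameter names (WT for the weight type, VT for the vector type):
//   f32 , f32  -> Raw = double
//   f32 , bf16 -> Raw = float  (both decompressed to f32)
//   bf16, bf16 -> Raw = BF16
// In all three cases the accumulation state is f64 (State = double), e.g.:
//   template <typename WT, typename VT>
//   using Raw = hwy::If<IsF32<WT>() && IsF32<VT>(), double,
//                       hwy::If<IsF32<WT>() || IsF32<VT>(), float, BF16>>;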
template - using Raw = hwy::If() && IsF32(), double, BF16>; + using Raw = hwy::If() && IsF32(), double, + hwy::If() || IsF32(), float, BF16>>; using State = double; // Raw = double template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2, + HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, const VR w3, const VR v0, const VR v1, const VR v2, const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const { @@ -175,6 +176,41 @@ struct DotKernelDouble { sum3 = hn::MulAdd(w3, v3, sum3); } + // Raw = float + template , HWY_IF_F32_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2, + const VR w3, const VR v0, const VR v1, const VR v2, + const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3, + VS&, VS&, VS&, VS&) const { + const hn::Repartition dd; + using VD = hn::Vec; + VD w0d = hn::PromoteLowerTo(dd, w0); + VD w1d = hn::PromoteLowerTo(dd, w1); + VD w2d = hn::PromoteLowerTo(dd, w2); + VD w3d = hn::PromoteLowerTo(dd, w3); + VD v0d = hn::PromoteLowerTo(dd, v0); + VD v1d = hn::PromoteLowerTo(dd, v1); + VD v2d = hn::PromoteLowerTo(dd, v2); + VD v3d = hn::PromoteLowerTo(dd, v3); + sum0 = hn::MulAdd(w0d, v0d, sum0); + sum1 = hn::MulAdd(w1d, v1d, sum1); + sum2 = hn::MulAdd(w2d, v2d, sum2); + sum3 = hn::MulAdd(w3d, v3d, sum3); + w0d = hn::PromoteUpperTo(dd, w0); + w1d = hn::PromoteUpperTo(dd, w1); + w2d = hn::PromoteUpperTo(dd, w2); + w3d = hn::PromoteUpperTo(dd, w3); + v0d = hn::PromoteUpperTo(dd, v0); + v1d = hn::PromoteUpperTo(dd, v1); + v2d = hn::PromoteUpperTo(dd, v2); + v3d = hn::PromoteUpperTo(dd, v3); + sum0 = hn::MulAdd(w0d, v0d, sum0); + sum1 = hn::MulAdd(w1d, v1d, sum1); + sum2 = hn::MulAdd(w2d, v2d, sum2); + sum3 = hn::MulAdd(w3d, v3d, sum3); + } + // Raw = BF16 template , HWY_IF_BF16_D(DRaw), class DS = hn::Repartition, class VS = hn::Vec> @@ -217,11 +253,26 @@ struct DotKernelDouble { // Raw = double template , HWY_IF_F64_D(DRaw)> - HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0, + HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VR& sum0, VR&) const { sum0 = hn::MulAdd(w0, v0, sum0); } + // Raw = float + template , HWY_IF_F32_D(DRaw), + class DS = hn::Repartition, class VS = hn::Vec> + HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VS& sum0, + VS&) const { + const hn::Repartition dd; + using VD = hn::Vec; + VD w0d = hn::PromoteLowerTo(dd, w0); + VD v0d = hn::PromoteLowerTo(dd, v0); + sum0 = hn::MulAdd(w0d, v0d, sum0); + w0d = hn::PromoteUpperTo(dd, w0); + v0d = hn::PromoteUpperTo(dd, v0); + sum0 = hn::MulAdd(w0d, v0d, sum0); + } + // Raw = BF16 template , HWY_IF_BF16_D(DRaw), class DS = hn::Repartition, class VS = hn::Vec> diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 4741759..9f279cb 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -113,7 +113,7 @@ class MMStoreHorizontalSumsIntoC { const size_t row_c, const size_t col_c, const MMArgs& args, RowPtrs C_rows) const { HWY_ALIGN float buf[16 * hn::MaxLanes(df)]; - const size_t N = hn::Lanes(df); + HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df); // Horizontal reductions (`ReduceSum`) are rather expensive, entailing // log(N) operations for vectors of length N. 
Because `kNR` == 4, we // instead use `StoreInterleaved4` for a vector length-agnostic @@ -230,7 +230,7 @@ class MMAddHorizontalSumsIntoPartial { const hn::Repartition dd; HWY_ALIGN double buf[16 * hn::MaxLanes(dd)]; using VD = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); VD C00 = SumOfPromotedPairs(dd, F00); VD C01 = SumOfPromotedPairs(dd, F01); VD C02 = SumOfPromotedPairs(dd, F02); @@ -351,8 +351,8 @@ class MMKernel { // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. - template - static HWY_INLINE void A2C0(const StridedViewBF& A_view, + template + static HWY_INLINE void A2C0(const StridedView A_view, const bool A_padded, const StridedViewBF& B_view, size_t mr, const IndexRange& range_mc, const size_t row_b, size_t kc, Tag tag, const MMArgs& args, @@ -365,8 +365,8 @@ class MMKernel { // M == 1, or x86 with 8 SIMD registers: if (HWY_UNLIKELY(mr == 1)) { for (; imc < mc; ++imc) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } return; } @@ -375,13 +375,13 @@ class MMKernel { if (HWY_UNLIKELY(mr == 2)) { if (HWY_LIKELY(mc >= 2)) { for (; imc <= mc - 2; imc += 2) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } } if (HWY_UNLIKELY(imc != mc)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } return; } @@ -389,18 +389,20 @@ class MMKernel { HWY_DASSERT(mr == 4); if (HWY_LIKELY(mc >= 4)) { for (; imc <= mc - 4; imc += 4) { - LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, - C_rows); + LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, + args, C_rows); } } const size_t remainder_mc = mc - imc; HWY_DASSERT(remainder_mc < 4); if (HWY_UNLIKELY(remainder_mc & 2)) { - LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); imc += 2; } if (HWY_UNLIKELY(remainder_mc & 1)) { - LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows); + LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args, + C_rows); imc += 1; } HWY_DASSERT(imc == mc); @@ -423,9 +425,10 @@ class MMKernel { // `MMAddHorizontalSumsIntoPartial`. template , class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE void ElementwiseMulAcc(DBF dbf, VBF a, VBF b0, VBF b1, - VBF b2, VBF b3, VF& C0, VF& C1, - VF& C2, VF& C3) { + static HWY_INLINE void ElementwiseMulAccNativeBF(DBF dbf, VBF a, VBF b0, + VBF b1, VBF b2, VBF b3, + VF& C0, VF& C1, VF& C2, + VF& C3) { // This handles a single row of A, so the horizontal sums of `C0..3` are the // (partial) dot products for 4 consecutive values in one row of C. static_assert(kNR == 4); @@ -443,16 +446,17 @@ class MMKernel { HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); } - // Like `ElementwiseMulAcc`, but splits BF16 inputs into odd and even f32 - // for use with FMA. Also handles two rows at a time to hide the FMA latency - // (we assume 4 cycles and dual-issue) before writing `C00` again. 
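// [Sketch] The odd/even BF16 emulation in minimal standalone form. The kernel
// below uses a project-specific FastPromoteOddTo and keeps 8 accumulators:
// with an assumed 4-cycle FMA latency and 2 FMAs/cycle, 4 x 2 = 8 independent
// chains are needed, hence two rows x four B columns per call.
const hn::ScalableTag<BF16> dbf;
const hn::Repartition<float, decltype(dbf)> df;
HWY_ALIGN BF16 pa[hn::MaxLanes(dbf)] = {};  // hypothetical inputs, zero here
HWY_ALIGN BF16 pb[hn::MaxLanes(dbf)] = {};
const auto a = hn::Load(dbf, pa);
const auto b = hn::Load(dbf, pb);
auto acc = hn::Zero(df);
// Promote even and odd BF16 lanes to f32 separately; both partial products
// land in the same accumulator, so ReduceSum over acc equals the f32 dot
// product of all BF16 lanes.
acc = hn::MulAdd(hn::PromoteEvenTo(df, a), hn::PromoteEvenTo(df, b), acc);
acc = hn::MulAdd(hn::PromoteOddTo(df, a), hn::PromoteOddTo(df, b), acc);
const float dot = hn::ReduceSum(df, acc);
(void)dot;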
+ // Like `ElementwiseMulAccNativeBF`, but splits BF16 inputs into odd and even + // f32 for use with FMA. Also handles two rows at a time to hide the FMA + // latency (we assume 4 cycles and dual-issue) before writing `C00` again. template , class DF = hn::Repartition, class VF = hn::Vec> - static HWY_INLINE void ElementwiseMulAcc2(DBF dbf, VBF a0, VBF a1, VF b0o, - VF b0e, VF b1o, VF b1e, VF b2o, - VF b2e, VF b3o, VF b3e, VF& C00, - VF& C01, VF& C02, VF& C03, VF& C10, - VF& C11, VF& C12, VF& C13) { + static HWY_INLINE void ElementwiseMulAccEmuBF(DBF dbf, VBF a0, VBF a1, VF b0o, + VF b0e, VF b1o, VF b1e, VF b2o, + VF b2e, VF b3o, VF b3e, VF& C00, + VF& C01, VF& C02, VF& C03, + VF& C10, VF& C11, VF& C12, + VF& C13) { const DF df; HWY_DASSERT(!HWY_NATIVE_DOT_BF16); // Avoid `ReorderWidenMulAccumulate` because it requires extra adds for @@ -491,20 +495,36 @@ class MMKernel { } } - // Innermost loop over `kc` columns (typically 1024-4096) in steps of one - // vector, for `kRowsAC` rows of `A_view` from range_mc-relative `imc` and - // `B_view` from row 0 (both at column 0). Updates a `kRowsAC x kNR` tile - // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be - // BF16 so we can load directly without `Decompress2`, which is expensive for - // NUQ and requires 2x unrolling, which requires more loads. - template - static HWY_INLINE void LoopKC(const StridedViewBF& A_view, + // For A=F32, B=BF16 without native BF16 dot product: one lane-crossing + // promotion is likely cheaper than AND+SHIFT for promoting odd/even BF. + // Caller already promoted B, so all inputs are F32. + template , HWY_IF_F32_D(DF)> + static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2, + VF b3, VF& C0, VF& C1, VF& C2, + VF& C3) { + HWY_DASSERT(!HWY_NATIVE_DOT_BF16); + C0 = hn::MulAdd(a, b0, C0); + C1 = hn::MulAdd(a, b1, C1); + C2 = hn::MulAdd(a, b2, C2); + C3 = hn::MulAdd(a, b3, C3); + } + + // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a + // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` + // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). + // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`. + // `B` is BF16, `A` and `C` can be F32 or BF16. + template + static HWY_INLINE void LoopKC(const StridedView A_view, + const bool A_padded, const StridedViewBF& B_view, size_t row_ac, size_t imc, size_t col_c, size_t kc, Tag tag, const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dbf; using VBF = hn::Vec; - const size_t NBF = hn::Lanes(dbf); + + HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag()); + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); HWY_DASSERT(kRowsAC <= kMaxMR); HWY_DASSERT(col_c % kNR == 0); @@ -512,30 +532,36 @@ class MMKernel { // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B. static_assert(kNR == 4); - const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0); - const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; - const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; - const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr; + const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0); + const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr; + const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr; + const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? 
A_view.Row(imc + 3) : nullptr; const BF16* HWY_RESTRICT br0 = B_view.Row(0); const BF16* HWY_RESTRICT br1 = B_view.Row(1); const BF16* HWY_RESTRICT br2 = B_view.Row(2); const BF16* HWY_RESTRICT br3 = B_view.Row(3); - // Ensure `A` and `B` were zero-padded by `DecompressAndZeroPad`. + // Ensure `A` and `B` were zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { + // Only check if `A` is padded, i.e. not packed. + if (A_padded) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) { + { + HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); + } + if constexpr (kRowsAC > 1) { + HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); + } + if constexpr (kRowsAC > 2) { + HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); + } + if constexpr (kRowsAC > 3) { + HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); + } + } + } + // B is unconditionally zero-padded by `DecompressAndZeroPad`. for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { - { - HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f); - } - if constexpr (kRowsAC > 1) { - HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f); - } - if constexpr (kRowsAC > 2) { - HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f); - } - if constexpr (kRowsAC > 3) { - HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f); - } HWY_DASSERT(hwy::ConvertScalarTo(br0[i]) == 0.0f); HWY_DASSERT(hwy::ConvertScalarTo(br1[i]) == 0.0f); HWY_DASSERT(hwy::ConvertScalarTo(br2[i]) == 0.0f); @@ -553,60 +579,287 @@ class MMKernel { C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df), C33 = hn::Zero(df); - HWY_UNROLL(1) - for (size_t ikc = 0; ikc < kc; ikc += NBF) { + size_t ikc = 0; + // The loop step is always NBF: for non-native BF16 with TA=F32, this + // entails 2x unrolling, which helps a little. + const HWY_LANES_CONSTEXPR size_t kc_step = NBF; + // If A is packed (not padded), we have to check for remainders. Otherwise, + // we only run the main loop because A's padding is zero-initialized by + // `ZeroInit` or weights.cc. + const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc; + if (kc_end >= kc_step) { + HWY_UNROLL(1) + for (; ikc <= kc_end - kc_step; ikc += kc_step) { + if constexpr (HWY_NATIVE_DOT_BF16) { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + + // Should only get here if `A` is BF16, otherwise `DecompressA` would + // convert to BF16 and `A_view` points to that. + HWY_DASSERT(IsBF16()); + + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VBF a1 = hn::Load(dbf, ar1 + ikc); + ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12, + C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VBF a3 = hn::Load(dbf, ar3 + ikc); + ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32, + C33); + } + } else { // !HWY_NATIVE_DOT_BF16 + if constexpr (IsBF16()) { + // When both are BF16, it is better to load promote odd/even, + // because lane-crossing promotion for both might be bottlenecked on + // shuffles. 
+ VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } + + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. + { + const VBF a0 = hn::Load(dbf, ar0 + ikc); + const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C00, C01, C02, C03, C10, C11, + C12, C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::Load(dbf, ar2 + ikc); + const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C20, C21, C22, C23, C30, C31, + C32, C33); + } + } else { // IsF32(): promote BF to 2xF32, F32*F32. + // Full-vector loads are a bit faster on SKX than half + PromoteTo. + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + const VF b00 = hn::PromoteLowerTo(df, b0); + const VF b10 = hn::PromoteLowerTo(df, b1); + const VF b20 = hn::PromoteLowerTo(df, b2); + const VF b30 = hn::PromoteLowerTo(df, b3); + const VF b01 = hn::PromoteUpperTo(df, b0); + const VF b11 = hn::PromoteUpperTo(df, b1); + const VF b21 = hn::PromoteUpperTo(df, b2); + const VF b31 = hn::PromoteUpperTo(df, b3); + + { + const VF a00 = hn::Load(df, ar0 + ikc); + ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a10 = hn::Load(df, ar1 + ikc); + ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, + C13); + } + + // C00 is ready again. On SKX, this interleaved unrolling is faster + // than consuming all `b*1` at the end of the loop. + { + const VF a01 = hn::Load(df, ar0 + ikc + NA); + ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a11 = hn::Load(df, ar1 + ikc + NA); + ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, + C13); + } + + if constexpr (kRowsAC > 2) { + const VF a20 = hn::Load(df, ar2 + ikc); + ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a30 = hn::Load(df, ar3 + ikc); + ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, + C33); + } + + if constexpr (kRowsAC > 2) { + const VF a21 = hn::Load(df, ar2 + ikc + NA); + ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a31 = hn::Load(df, ar3 + ikc + NA); + ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, + C33); + } + } + } + } + } + + // We want the number of actual valid kc, but we may already be beyond `kc`. + const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc; + HWY_DASSERT(remaining_kc < kc_step); + HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0)); + // Last iteration: B is padded but A is not; guard its loads. 
+ if (HWY_UNLIKELY(remaining_kc != 0)) { if constexpr (HWY_NATIVE_DOT_BF16) { const VBF b0 = hn::Load(dbf, br0 + ikc); const VBF b1 = hn::Load(dbf, br1 + ikc); const VBF b2 = hn::Load(dbf, br2 + ikc); const VBF b3 = hn::Load(dbf, br3 + ikc); + + // Should only get here if `A` is BF16, otherwise `DecompressA` would + // convert to BF16 and `A_view` points to that. + HWY_DASSERT(IsBF16()); + { - const VBF a0 = hn::Load(dbf, ar0 + ikc); - ElementwiseMulAcc(dbf, a0, b0, b1, b2, b3, C00, C01, C02, C03); + const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02, + C03); } if constexpr (kRowsAC > 1) { - const VBF a1 = hn::Load(dbf, ar1 + ikc); - ElementwiseMulAcc(dbf, a1, b0, b1, b2, b3, C10, C11, C12, C13); + const VBF a1 = hn::LoadN(dbf, ar1 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12, + C13); } if constexpr (kRowsAC > 2) { - const VBF a2 = hn::Load(dbf, ar2 + ikc); - ElementwiseMulAcc(dbf, a2, b0, b1, b2, b3, C20, C21, C22, C23); + const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22, + C23); } if constexpr (kRowsAC > 3) { - const VBF a3 = hn::Load(dbf, ar3 + ikc); - ElementwiseMulAcc(dbf, a3, b0, b1, b2, b3, C30, C31, C32, C33); + const VBF a3 = hn::LoadN(dbf, ar3 + ikc, remaining_kc); + ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32, + C33); } - } else { - VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; - { + } else { // !HWY_NATIVE_DOT_BF16 + if constexpr (IsBF16()) { + // When both are BF16, it is better to load promote odd/even, because + // lane-crossing promotion for both might be bottlenecked on shuffles. + VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o; + { + const VBF b0 = hn::Load(dbf, br0 + ikc); + const VBF b1 = hn::Load(dbf, br1 + ikc); + const VBF b2 = hn::Load(dbf, br2 + ikc); + const VBF b3 = hn::Load(dbf, br3 + ikc); + b0e = hn::PromoteEvenTo(df, b0); + b1e = hn::PromoteEvenTo(df, b1); + b2e = hn::PromoteEvenTo(df, b2); + b3e = hn::PromoteEvenTo(df, b3); + b0o = FastPromoteOddTo(df, b0); + b1o = FastPromoteOddTo(df, b1); + b2o = FastPromoteOddTo(df, b2); + b3o = FastPromoteOddTo(df, b3); + } + + // Two rows at a time so we have 8 separate dependency chains, + // sufficient for IPC=2 and 4-cycle latency. + { + const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc); + const VBF a1 = + kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0; + ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C00, C01, C02, C03, C10, C11, C12, + C13); + } + if constexpr (kRowsAC > 2) { + const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc); + const VBF a3 = + kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2; + ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, + b3o, b3e, C20, C21, C22, C23, C30, C31, C32, + C33); + } + } else { // IsF32(): promote half-B to F32, F32*F32. 
const VBF b0 = hn::Load(dbf, br0 + ikc); const VBF b1 = hn::Load(dbf, br1 + ikc); const VBF b2 = hn::Load(dbf, br2 + ikc); const VBF b3 = hn::Load(dbf, br3 + ikc); - b0e = hn::PromoteEvenTo(df, b0); - b1e = hn::PromoteEvenTo(df, b1); - b2e = hn::PromoteEvenTo(df, b2); - b3e = hn::PromoteEvenTo(df, b3); - b0o = FastPromoteOddTo(df, b0); - b1o = FastPromoteOddTo(df, b1); - b2o = FastPromoteOddTo(df, b2); - b3o = FastPromoteOddTo(df, b3); - } + const VF b00 = hn::PromoteLowerTo(df, b0); + const VF b10 = hn::PromoteLowerTo(df, b1); + const VF b20 = hn::PromoteLowerTo(df, b2); + const VF b30 = hn::PromoteLowerTo(df, b3); + const VF b01 = hn::PromoteUpperTo(df, b0); + const VF b11 = hn::PromoteUpperTo(df, b1); + const VF b21 = hn::PromoteUpperTo(df, b2); + const VF b31 = hn::PromoteUpperTo(df, b3); - { - const VBF a0 = hn::Load(dbf, ar0 + ikc); - const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0; - ElementwiseMulAcc2(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o, - b3e, C00, C01, C02, C03, C10, C11, C12, C13); - } - if constexpr (kRowsAC > 2) { - const VBF a2 = hn::Load(dbf, ar2 + ikc); - const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2; - ElementwiseMulAcc2(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o, - b3e, C20, C21, C22, C23, C30, C31, C32, C33); + const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA; + + { + const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12, + C13); + } + + // C00 is ready again. On SKX, this interleaved unrolling is faster + // than consuming all `b*1` at the end of the loop. + { + const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02, + C03); + } + if constexpr (kRowsAC > 1) { + const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12, + C13); + } + + if constexpr (kRowsAC > 2) { + const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc); + ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32, + C33); + } + + if constexpr (kRowsAC > 2) { + const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22, + C23); + } + if constexpr (kRowsAC > 3) { + const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2); + ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32, + C33); + } } } - } + } // remaining_kc != 0 // This is a substantial fraction (about 1/3) of the total time, but is // called frequently, so do not add a profiler zone. 
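// [Sketch] Why the guarded tail loads above stay in bounds: LoadN reads at
// most `n` lanes and zeroes the remaining lanes, so the extra lanes contribute
// 0 to the FMAs. In the F32 path, one NBF-wide step is two NA-wide f32 loads
// (NA == NBF / 2), hence the second load is given
//   remaining2 = (remaining_kc <= NA) ? 0 : remaining_kc - NA.
// Minimal standalone illustration (for targets with >= 3 f32 lanes):
const hn::ScalableTag<float> df;
const float tail[3] = {1.0f, 2.0f, 3.0f};  // only 3 valid values
const auto v = hn::LoadN(df, tail, 3);     // lanes 3..Lanes(df)-1 are zero
HWY_DASSERT(hn::ReduceSum(df, v) == 6.0f);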
@@ -678,7 +931,7 @@ class MMScaleDemoteAdd { const hn::Rebind dc; using VD = hn::Vec; using VF = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); @@ -796,7 +1049,7 @@ class MMScaleDemoteAdd { const hn::Rebind dc; using VD = hn::Vec; using VF = hn::Vec; - const size_t ND = hn::Lanes(dd); + HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd); const VD vscale = hn::Set(dd, args.scale); const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); TC* HWY_RESTRICT cr0 = C_rows[row_c + 0]; @@ -858,41 +1111,51 @@ class MMScaleDemoteAdd { // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC. // Its member variables avoid long argument lists in Do*(). class MMPerPackage { - public: + // Decompression is only required for F32 A and native BF16 dot products. + // If A is already BF16, we can use a view. Padding is not required + // because `LoopKC` can handle non-vector multiples. `LoopKC` also contains + // a special case for F32 `A` and non-native BF16 dot products. template - MMPerPackage(const MatPtrT& A, const MMArgs& args, const MMConfig& config, + static constexpr bool WantDecompressA() { + return HWY_NATIVE_DOT_BF16 && IsF32(); + } + + public: + MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config, size_t pkg_idx, const IndexRange& range_np) : args_(args), pkg_idx_(pkg_idx), - // May be overwritten with a view of A, if already BF16. - A_(args_.env->storage.A(pkg_idx, A.Extents())), range_np_(range_np), mr_(config.MR()), - ranges_mc_(config.RangesOfMC(A.Rows())), - ranges_kc_(config.RangesOfKC(A.Cols())), + ranges_mc_(config.RangesOfMC(A.rows)), + ranges_kc_(config.RangesOfKC(A.cols)), ranges_nc_(config.RangesOfNC(range_np)), order_(config.Order()), inner_tasks_(config.InnerTasks()), out_(config.Out()), - line_bytes_(args.env->ctx.allocator.LineBytes()) { - A_ = DecompressA(A); + line_bytes_(args.env->ctx.allocator.LineBytes()) {} + + // The size of `A` that will actually be used, for purposes of choosing the + // autotuning candidates. Keep in sync with the `operator()` logic below. + template + static constexpr size_t ABytes() { + return WantDecompressA() ? sizeof(BF16) : sizeof(TA); } - // B is decompressed several call layers lower, but not all member functions - // depend on TB, so pass it as an argument instead of templating the class. - template - HWY_NOINLINE void operator()(const MatPtrT& B, RowPtrs C_rows) const { - switch (order_) { - case MMOrder::kNT: - return DoNT(B, C_rows); - case MMOrder::kNT_K: - return DoNT_K(B, C_rows); - case MMOrder::kNT_MT: - return DoNT_MT(B, C_rows); - case MMOrder::kNT_MT_K: - return DoNT_MT_K(B, C_rows); - default: - HWY_UNREACHABLE; + // B and maybe A are decompressed several call layers lower, but not all + // member functions depend on TA/TB, so pass them as an argument instead of + // templating the class. + template + HWY_NOINLINE void operator()(const MatPtrT& A, const MatPtrT& B, + RowPtrs C_rows) const { + if constexpr (WantDecompressA()) { + const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); + DecompressA(A, A_view); + constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded. 
+ DispatchOrder(A_view, A_padded, B, C_rows); + } else { + const bool A_padded = HasPadding(A); + DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows); } } @@ -909,16 +1172,57 @@ class MMPerPackage { return HWY_MAX(kNR, line_bytes_ / sizeof_TC); } + // Use instead of `MatPtr::IsPacked` because that returns true for single + // rows, but we want to know whether there is padding. + static bool HasPadding(const MatPtr& mat) { + return mat.Stride() > mat.Cols(); + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`` + // and `B` are const, but StridedView is also used for non-const `partial`. + template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag()); + (void)N; + // If `AB` is padded, then `LoopKC` expects the view is either a vector + // multiple, or all columns and thus also padded. + HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols())); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + + // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`. + template + HWY_INLINE void DispatchOrder(const StridedView A, const bool A_padded, + const MatPtrT& B, + RowPtrs C_rows) const { + switch (order_) { + case MMOrder::kNT: + return DoNT(A, A_padded, B, C_rows); + case MMOrder::kNT_K: + return DoNT_K(A, A_padded, B, C_rows); + case MMOrder::kNT_MT: + return DoNT_MT(A, A_padded, B, C_rows); + case MMOrder::kNT_MT_K: + return DoNT_MT_K(A, A_padded, B, C_rows); + default: + HWY_UNREACHABLE; + } + } + // Single M and K ranges, parallel N. Fills all of C directly. - template - HWY_INLINE void DoNT(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_M = ranges_mc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K); + const StridedView A_view = A.View(range_M.begin(), 0, K); const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); @@ -936,8 +1240,8 @@ class MMPerPackage { row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), - args_, C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K, + MMSetC(), args_, C_rows); } }); @@ -945,8 +1249,9 @@ class MMPerPackage { } // Single M range, parallel N, sequential K. Fills all of partial. 
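// [Aside on the View() precondition above] Numeric illustration, assuming
// 512-bit vectors (N = 16 f32 lanes) and a padded A with Cols() == 258:
//   View(A, 0, 0, 258)  // ok: all columns, so the view is also padded
//   View(A, 0, 0, 256)  // ok: a whole number of vectors (256 % 16 == 0)
//   View(A, 0, 0, 130)  // trips the HWY_DASSERT (130 % 16 != 0 and != Cols())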
- template - HWY_INLINE void DoNT_K(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_K(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K"); HWY_DASSERT(ranges_mc_.NumTasks() == 1); const IndexRange& range_mc = ranges_mc_.Range(0); @@ -958,8 +1263,8 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedViewBF& A_view = - A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedView A_view = + A.View(range_mc.begin(), range_kc.begin(), kc); const StridedViewBF B_storage_view( B_storage, kc, Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); @@ -967,8 +1272,8 @@ class MMPerPackage { for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, + out_tag, args_, C_rows); } }; @@ -1013,8 +1318,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. - template - HWY_INLINE void DoNT_MT(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT"); HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); @@ -1031,7 +1337,7 @@ class MMPerPackage { MMZone mm_zone; mm_zone.MaybeEnter(worker, zone, args_); - const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K); + const StridedView A_view = A.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS const StridedViewBF B_storage_view(B_storage, K, B_stride); @@ -1039,8 +1345,8 @@ class MMPerPackage { row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), - args_, C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K, + MMSetC(), args_, C_rows); } }); @@ -1049,8 +1355,9 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. 
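// [Recap] The four MMOrder variants dispatched by DispatchOrder above,
// summarized from the Do* comments:
//   kNT      : single MC and KC range, parallel N          -> fills C directly
//   kNT_K    : single MC range, parallel N, sequential KC  -> fills f64 partial, then C
//   kNT_MT   : parallel MC x NC blocks, single KC range    -> fills C directly
//   kNT_MT_K : parallel MC x NC blocks, sequential KC      -> fills f64 partial, then C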
- template - HWY_INLINE void DoNT_MT_K(const MatPtrT& B, RowPtrs C_rows) const { + template + HWY_INLINE void DoNT_MT_K(const StridedView A, const bool A_padded, + const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); static const auto fill_zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC"); @@ -1067,14 +1374,14 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const StridedViewBF& A_view = - A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedView A_view = + A.View(range_mc.begin(), range_kc.begin(), kc); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); - MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, - C_rows); + MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc, + out_tag, args_, C_rows); } }; // loop_nc args_.env->parallel.ForRangesMC_NC( @@ -1107,17 +1414,16 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a - // seekable type (i.e., not NUQ) so we can use pointer arithmetic. - template - HWY_NOINLINE void DoDecompressA(const MatPtrT& A, MMParA par_a) const { + // Decompresses all `M x K` from `A` into padded BF16 `A_view`. + HWY_NOINLINE void DoDecompressA(const MatPtrT& A, + const StridedViewBF A_view, + MMParA par_a) const { const IndexRange all_M(0, A.Rows()); const IndexRange all_K(0, A.Cols()); - HWY_DASSERT(all_K.Num() == A_.Cols()); + HWY_DASSERT(all_K.Num() == A_view.Cols()); const hn::ScalableTag dbf; const size_t NBF = hn::Lanes(dbf); - static_assert(hwy::IsSameEither(), "Can seek"); static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA"); @@ -1133,8 +1439,9 @@ class MMPerPackage { // otherwise `DecompressAndZeroPad` overwrites neighbors. HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols()); for (size_t row_a : range_M) { - const PackedSpan from = MakeSpan(A.Row(row_a) + col0, cols); - BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; + const PackedSpan from = + MakeSpan(A.Row(row_a) + col0, cols); + BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0; DecompressAndZeroPad(dbf, from, 0, to, cols); // Verify that we zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { @@ -1175,23 +1482,12 @@ class MMPerPackage { } // Autotuning wrapper for `DoDecompressA`. - template - HWY_INLINE StridedViewBF DecompressA(const MatPtrT& A) const { + HWY_INLINE void DecompressA(const MatPtrT& A, + const StridedViewBF A_view) const { MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; - // If already BF16, maybe return a view: - if constexpr (hwy::IsSame()) { - // Only if vector multiple and padded (see `DoDecompressA`). - const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) { - // Const, but cast because StridedView is also used for `partial` which - // is non-const. - return StridedViewBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); - } - } if (HWY_LIKELY(autotune.Best())) { - DoDecompressA(A, *autotune.Best()); - return A_; + return DoDecompressA(A, A_view, *autotune.Best()); } // First call: generate candidates. @@ -1204,7 +1500,7 @@ class MMPerPackage { const MMParA& par_a = autotune.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - DoDecompressA(A, par_a); + DoDecompressA(A, A_view, par_a); const uint64_t t1 = args_.env->have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0); @@ -1213,7 +1509,6 @@ class MMPerPackage { static_cast(min_elapsed) / hwy::platform::InvariantTicksPerSecond() * 1E6); } - return A_; } // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0, @@ -1223,12 +1518,17 @@ class MMPerPackage { HWY_INLINE StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, const IndexRange& range_kc, const StridedViewBF& B_view) const { + const hn::ScalableTag dbf; + HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf); + + // View() is safe if vector multiple, or padded: for the latter, `ZeroInit` + // and weights.cc zero-initialize the padding. if constexpr (hwy::IsSame()) { - return StridedViewBF(const_cast(B.Row(row_b)) + range_kc.begin(), - range_kc.Num(), B.Stride()); + if (B.Cols() % NBF == 0 || HasPadding(B)) { + return View(B, row_b, range_kc.begin(), range_kc.Num()); + } } - const hn::ScalableTag dbf; const PackedSpan B_span = B.PaddedSpan(); const size_t kc = range_kc.Num(); @@ -1240,7 +1540,7 @@ class MMPerPackage { DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc); // Verify that we zero-padded. if constexpr (HWY_IS_DEBUG_BUILD) { - for (size_t i = kc; i < hwy::RoundUpTo(kc, hn::Lanes(dbf)); ++i) { + for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) { HWY_DASSERT(hwy::ConvertScalarTo(to[i]) == 0.0f); } } @@ -1250,7 +1550,6 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; - StridedViewBF A_; // view into A or pkg_A_, both of which are padded. const IndexRange range_np_; // From MMConfig: @@ -1293,13 +1592,14 @@ struct MMImpl { MMZone mm_zone; mm_zone.MaybeEnter(pkg_idx, zone, args); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, + C_rows); }); } else { const size_t pkg_idx = 0; HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); - MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows); + MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows); } } }; @@ -1310,7 +1610,7 @@ struct MMImpl { // `K = B.Cols()`, which must match `A.Cols()`, is the number // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There // are no other restrictions on shape, though performance is better when `M % 4 -// == 0` or `M <= 4`, and when A is padded (`!A.IsPacked()`). +// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()). 
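// [Sketch] One way a caller can satisfy the zero-init requirement in the NOTE
// below, mirroring the GenerateMat change in this patch (`a` is a hypothetical
// padded BF16 matrix, filled row by row):
//   for (size_t r = 0; r < a.Rows(); ++r) {
//     FillRow(a.Row(r), a.Cols());       // hypothetical: write valid columns
//     hwy::ZeroBytes(a.Row(r) + a.Cols(),
//                    (a.Stride() - a.Cols()) * a.ElementBytes());
//   }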
// // NOTE: if A and/or B are BF16 and padded, the interval `[Cols(), // hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match @@ -1376,8 +1676,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(N <= MMStorage::kMaxN); HWY_ASSERT(N % kNR == 0); - tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, - kNR, per_key.ranges_np, env.print_config)); + tuner.SetCandidates( + MMCandidates(allocator, M, K, N, MMPerPackage::ABytes(), sizeof(TC), + kMaxMR, kNR, per_key.ranges_np, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index c51acbd..c9ddfb6 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -64,19 +64,21 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, class GenerateCandidates { public: GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, - size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) + size_t sizeof_TA, size_t sizeof_TC, size_t max_mr, + size_t nr, const IndexRangePartition& ranges_np, + bool print_config) : allocator_(allocator), M_(M), K_(K), N_(N), + sizeof_TA_(sizeof_TA), sizeof_TC_(sizeof_TC), max_mr_(max_mr), nr_(nr), // These influence kc/nc, but are also stored in `MMConfig` for // `RangesOf*`. Must be a vector multiple. The previous/next cache line // is likely still in L1, but we expect K > 1000 and might as well round - // up to the line size. + // up to the line size. Use BF16, not sizeof_TA, because B is BF16. kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), nc_multiple_(allocator.StepBytes() / sizeof_TC), ranges_np_(ranges_np), @@ -176,8 +178,9 @@ class GenerateCandidates { // subtract the output and buf, and allow using more than the actual L1 // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. - const size_t bytes_ab = allocator_.L1Bytes() * 3; - const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); + const size_t bytes_ab = + allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream)); + const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); kc_max = RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); @@ -224,7 +227,7 @@ class GenerateCandidates { // packed B. We want `mc * kc` elements of A to fit in L2, alongside // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of // partial. - const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes(); + const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes(); size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc); mc_max = HWY_MIN(mc_max, MMStorage::kMaxM); HWY_DASSERT(mc_max != 0); @@ -359,6 +362,7 @@ class GenerateCandidates { const size_t M_; const size_t K_; const size_t N_; + const size_t sizeof_TA_; const size_t sizeof_TC_; const size_t max_mr_; @@ -376,12 +380,12 @@ class GenerateCandidates { // Facade to avoid exposing `GenerateCandidates` in the header. 
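// [Worked example for the kc_max estimate above] Assuming a 32 KiB L1,
// A = f32 (sizeof_TA = 4), rows_a = mr = 4, nr = 4 and sizeof(SfpStream) == 1
// (illustrative numbers, not from the patch):
//   bytes_ab  = 32768 * (4 + 1)          = 163840
//   col_bytes = 4 * 4 + 4 * sizeof(BF16) = 24
//   kc_max    = DivCeil(163840, 24)      = 6827,
// which is then clamped to MMStorage::kMaxKC and rounded down to kc_multiple.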
std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TA, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) { - return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, - ranges_np, print_config)(); + return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr, + nr, ranges_np, print_config)(); } // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote diff --git a/ops/matmul.h b/ops/matmul.h index de8ef8c..99290c1 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -281,7 +281,9 @@ class MMStorage { BindC(partial_storage_, parallel); } - // Returns per-package matrix view. + // Returns per-package matrix view. Converting A=F32 to BF16 up-front is + // faster than on-the-fly when native BF16 is available: it only happens once, + // not per B tile row, and the cache footprint is smaller. StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); @@ -475,8 +477,8 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) std::vector MMCandidates(const Allocator& allocator, size_t M, - size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, + size_t K, size_t N, size_t sizeof_TA, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config); diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 122012e..aadbc56 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -118,17 +118,26 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch); const double eps_bf16 = hwy::ConvertScalarTo(hwy::Epsilon()); const double eps_f32 = hwy::ConvertScalarTo(hwy::Epsilon()); + // Dot() uses double-precision summation. double tolerance = 12 * norm * eps_f32; - // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the - // tolerance there. - if (IsF32() && IsF32()) { - tolerance += 4 * max_abs * eps_bf16; + // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to + // BF16, so add extra tolerance. 
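// [Aside] For scale: eps_bf16 = 2^-7 ~= 7.8e-3 whereas eps_f32 = 2^-23
// ~= 1.2e-7, i.e. the extra term below covers rounding that is 2^16 coarser
// than F32 rounding.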
+ if (IsF32()) { + tolerance += 2 * max_abs * eps_bf16; } + if (tolerance > 500.0) { HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); } - const double max_rel = 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); + const double rel_tolerance = + 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); + double max_rel = 0.0; + size_t worst_r = 0; + size_t worst_c = 0; + double worst_actual = 0.0; + double worst_expected = 0.0; + size_t num_outside = 0; for (size_t r = 0; r < A.Rows(); r++) { const float* expected_row = c_slow_batch.Row(r); const float* actual_row = c_batch.Row(r); @@ -143,15 +152,24 @@ void AssertClose(const MatPtrT& A, const MatPtrT& B, const double min = HWY_MIN(expected_value, actual_value); const double rel = max / HWY_MAX(min, 1E-6); if (rel > max_rel) { - hwy::Abort(__FILE__, line, - "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " - "tolerance %f rel %E max_rel %E\n", - r, c, expected_value, actual_value, norm, max_abs, - tolerance, rel, max_rel); + worst_expected = expected_value; + worst_actual = actual_value; + worst_r = r; + worst_c = c; + max_rel = rel; + ++num_outside; } } } } + + if (max_rel > rel_tolerance) { + hwy::Abort(__FILE__, line, + "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " + "tolerance %f rel %E max_rel %E num_outside %zu\n", + worst_r, worst_c, worst_expected, worst_actual, norm, max_abs, + tolerance, max_rel, rel_tolerance, num_outside); + } } // B is already transposed. @@ -188,9 +206,9 @@ HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, TC* HWY_RESTRICT C_row = C.Row(r); for (size_t c : cols_c) { const float add = add_row ? add_row[c] : 0.0f; - C_row[c] = hwy::ConvertScalarTo( - add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r), - A.Cols())); + const float dot = + Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols()); + C_row[c] = hwy::ConvertScalarTo(add + scale * dot); } } }); @@ -279,6 +297,9 @@ void TestTiny() { for (size_t K = 1; K <= 64; K *= 2) { for (size_t N = 4; N <= 64; N += max_packages * 4) { TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); } } } @@ -334,6 +355,10 @@ void TestAllMatMul() { TestMatMul(256, 256, 256, /*add=*/false, env, __LINE__); TestMatMul(256, 256, 256, /*add=*/true, env, __LINE__); + // Non-vector-multiple K. + TestMatMul(128, 258, 128, /*add=*/true, env, __LINE__); + TestMatMul(128, 258, 128, /*add=*/true, env, __LINE__); + // minimal non-square test. kColsARowsB must be at least 2 vectors. TestMatMul(35, 128, 32, /*add=*/false, env, __LINE__); TestMatMul(34, 128, 32, /*add=*/true, env, __LINE__);
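// [Aside on the new K = 258 tests] Assuming 512-bit vectors and a single KC
// range (kc == 258): one LoopKC step covers NBF = 32 BF16 columns (or 2 x 16
// f32 columns), so 258 = 8 * 32 + 2 leaves a 2-column tail. Packed A then
// takes the LoadN path with remaining_kc = 2 (and remaining2 = 0 in the F32
// variant), while padded A instead reads its zero-initialized padding columns.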