From 0ae8646731e3c80c56035083118c0d7310be20da Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Fri, 29 Aug 2025 07:25:14 -0700
Subject: [PATCH] Fix remainder handling for Paligemma

No longer attempt to skip the remainder handling because B might also be a
non-padded view.

PiperOrigin-RevId: 800890805
---
 compression/test_util-inl.h |  10 --
 gemma/weights.cc            |   3 -
 ops/matmul-inl.h            | 202 ++++++++++++++----------------------
 ops/matmul.h                |   1 +
 ops/matmul_test.cc          |   1 +
 5 files changed, 79 insertions(+), 138 deletions(-)

diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h
index 207b225..e5b1fe0 100644
--- a/compression/test_util-inl.h
+++ b/compression/test_util-inl.h
@@ -87,11 +87,6 @@ MatStorageT GenerateMat(const Extents2D& extents,
     Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
              MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0);
-
-    // MatMul requires that A's padding be zero-initialized.
-    hwy::ZeroBytes(
-        compressed.Row(r) + extents.cols,
-        (compressed.Stride() - extents.cols) * compressed.ElementBytes());
   });
 
   compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
@@ -120,11 +115,6 @@ MatStorageT GenerateTransposedMat(const Extents2D extents,
     Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
              MakeSpan(compressed.Row(r), extents.cols), /*packed_ofs=*/0);
-
-    // MatMul requires that B's padding be zero-initialized.
-    hwy::ZeroBytes(
-        compressed.Row(r) + extents.cols,
-        (compressed.Stride() - extents.cols) * compressed.ElementBytes());
   });
 
   // Arbitrary value, different from 1, must match `GenerateMat`.
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 3425a60..ca1cebc 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -444,9 +444,6 @@ static std::vector MakeBatches(
       HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
     }
     offset += file_bytes_per_row;
-    // Must zero-initialize the in-memory row padding, see MatMul.
-    hwy::ZeroBytes(row_bytes + file_bytes_per_row,
-                   mem_stride_bytes - file_bytes_per_row);
    row_bytes += mem_stride_bytes;
  }
  HWY_ASSERT(offset == range.End());
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 29be665..b54ce05 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -216,7 +216,7 @@ class MMKernel {
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
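// --- Illustrative sketch (editor's addition, not part of the commit) ---
// Why zero-padding can no longer be assumed: a view of the left columns of a
// wider matrix keeps the parent's stride, so elements in [cols, stride) of
// each view row are the parent's live data, not padding. All names and sizes
// here are hypothetical.
#include <cstddef>

struct LeftViewDemo {
  static constexpr size_t kStride = 8;    // parent row stride, in elements
  static constexpr size_t kViewCols = 5;  // width of the left-half view
  float parent[2 * kStride] = {};         // 2 rows, fully populated

  // Row r of the view. Elements [kViewCols, kStride) alias the parent's
  // right columns: zeroing them (as the deleted ZeroBytes calls above did
  // for owned storage) would corrupt the parent, so the kernel must instead
  // stop reading at kViewCols - see the masked remainder loads below.
  const float* ViewRow(size_t r) const { return parent + r * kStride; }
};
// ------------------------------------------------------------------------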
   template
-  static HWY_INLINE void A2C0(const StridedView A_view, const bool A_padded,
+  static HWY_INLINE void A2C0(const StridedView A_view,
                               const StridedViewBF& B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
@@ -229,8 +229,8 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -239,13 +239,13 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                    args, C_rows);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -253,20 +253,18 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
      imc += 1;
    }
    HWY_DASSERT(imc == mc);
@@ -380,7 +378,6 @@ class MMKernel {
   // `B` is BF16, `A` and `C` can be F32 or BF16.
   template
   static HWY_INLINE void LoopKC(const StridedView A_view,
-                                const bool A_padded,
                                 const StridedViewBF& B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
                                 const MMArgs& args, RowPtrs C_rows) {
@@ -405,33 +402,8 @@ class MMKernel {
     const BF16* HWY_RESTRICT br2 = B_view.Row(2);
     const BF16* HWY_RESTRICT br3 = B_view.Row(3);
 
-    // Ensure `A` and `B` were zero-padded.
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      // Only check if `A` is padded, i.e. not packed.
-      if (A_padded) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
-          {
-            HWY_DASSERT(hwy::ConvertScalarTo(ar0[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 1) {
-            HWY_DASSERT(hwy::ConvertScalarTo(ar1[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 2) {
-            HWY_DASSERT(hwy::ConvertScalarTo(ar2[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 3) {
-            HWY_DASSERT(hwy::ConvertScalarTo(ar3[i]) == 0.0f);
-          }
-        }
-      }
-      // B is unconditionally zero-padded by `DecompressAndZeroPad`.
-      for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-        HWY_DASSERT(hwy::ConvertScalarTo(br0[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo(br1[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo(br2[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo(br3[i]) == 0.0f);
-      }
-    }
+    // Neither A nor B is guaranteed to be zero-padded: either might be a
+    // view into the left half of a wider matrix.
 
     // Accumulate into f32.
     const hn::Repartition df;
@@ -447,18 +419,18 @@ class MMKernel {
     // The loop step is always NBF: for non-native BF16 with TA=F32, this
     // entails 2x unrolling, which helps a little.
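// --- Illustrative sketch (editor's addition, not part of the commit) ---
// The row-remainder pattern of A2C0's mr == 4 path above, reduced to scalar
// form: blocks of 4 rows, then an optional block of 2 and one of 1, mirroring
// `remainder_mc & 2` / `remainder_mc & 1`. Names are hypothetical.
#include <cstddef>

void ForEachRowBlock(size_t mc, void (*kernel)(size_t row0, size_t rows)) {
  size_t imc = 0;
  if (mc >= 4) {
    for (; imc <= mc - 4; imc += 4) kernel(imc, 4);  // main blocks of 4
  }
  const size_t remainder_mc = mc - imc;  // < 4
  if (remainder_mc & 2) {
    kernel(imc, 2);
    imc += 2;
  }
  if (remainder_mc & 1) {
    kernel(imc, 1);
    imc += 1;
  }
  // Here imc == mc: every row was visited exactly once.
}
// ------------------------------------------------------------------------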
     const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
-    // If A is packed (not padded), we have to check for remainders. Otherwise,
-    // we only run the main loop because A's padding is zero-initialized by
-    // `ZeroInit` or weights.cc.
-    const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
-    if (kc_end >= kc_step) {
+    if (kc >= kc_step) {
       HWY_UNROLL(1)
-      for (; ikc <= kc_end - kc_step; ikc += kc_step) {
+      for (; ikc <= kc - kc_step; ikc += kc_step) {
         if constexpr (HWY_NATIVE_DOT_BF16) {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          // NOTE: matmul_test has packed B so that it can call Span. The test
+          // cases with non-vector-multiple K require unaligned loads here.
+          // However, in actual usage, we should always have padded and thus
+          // aligned A and B.
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
 
           // Should only get here if `A` is BF16, otherwise `DecompressA` would
           // convert to BF16 and `A_view` points to that.
@@ -491,10 +463,10 @@ class MMKernel {
           // shuffles.
           VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
           {
-            const VBF b0 = hn::Load(dbf, br0 + ikc);
-            const VBF b1 = hn::Load(dbf, br1 + ikc);
-            const VBF b2 = hn::Load(dbf, br2 + ikc);
-            const VBF b3 = hn::Load(dbf, br3 + ikc);
+            const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+            const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+            const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+            const VBF b3 = hn::LoadU(dbf, br3 + ikc);
             b0e = hn::PromoteEvenTo(df, b0);
             b1e = hn::PromoteEvenTo(df, b1);
             b2e = hn::PromoteEvenTo(df, b2);
@@ -523,10 +495,10 @@ class MMKernel {
           }
         } else {  // IsF32(): promote BF to 2xF32, F32*F32.
           // Full-vector loads are a bit faster on SKX than half + PromoteTo.
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
           const VF b00 = hn::PromoteLowerTo(df, b0);
           const VF b10 = hn::PromoteLowerTo(df, b1);
           const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -586,17 +558,16 @@ class MMKernel {
       }
     }
 
-    // We want the number of actual valid kc, but we may already be beyond `kc`.
-    const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
+    // Always handle remainders: even though A and B are generally padded, we
+    // might have a view into the left half of A and/or B.
+    const size_t remaining_kc = kc - ikc;
     HWY_DASSERT(remaining_kc < kc_step);
-    HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
-    // Last iteration: B is padded but A is not; guard its loads.
     if (HWY_UNLIKELY(remaining_kc != 0)) {
       if constexpr (HWY_NATIVE_DOT_BF16) {
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
 
         // Should only get here if `A` is BF16, otherwise `DecompressA` would
         // convert to BF16 and `A_view` points to that.
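// --- Illustrative sketch (editor's addition, not part of the commit) ---
// The remainder-handling pattern adopted above, as a standalone Highway dot
// product over floats: full vectors via LoadU, then one masked LoadN for the
// tail. LoadN loads `remaining` lanes and zeroes the rest, so the tail
// contributes exactly the valid elements - no zeroed padding required.
// Assumes static dispatch (single target); this is not the actual kernel.
#include <cstddef>

#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

float DotF32(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
             size_t num) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  hn::Vec<decltype(d)> sum = hn::Zero(d);
  size_t i = 0;
  if (num >= N) {
    for (; i <= num - N; i += N) {
      sum = hn::MulAdd(hn::LoadU(d, a + i), hn::LoadU(d, b + i), sum);
    }
  }
  const size_t remaining = num - i;  // < N
  if (remaining != 0) {
    // Lanes >= remaining are zero in both operands and thus add 0 to sum.
    sum = hn::MulAdd(hn::LoadN(d, a + i, remaining),
                     hn::LoadN(d, b + i, remaining), sum);
  }
  return hn::ReduceSum(d, sum);
}
// ------------------------------------------------------------------------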
@@ -628,10 +599,10 @@ class MMKernel {
         // lane-crossing promotion for both might be bottlenecked on shuffles.
         VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
         {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+          const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+          const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+          const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
           b0e = hn::PromoteEvenTo(df, b0);
           b1e = hn::PromoteEvenTo(df, b1);
           b2e = hn::PromoteEvenTo(df, b2);
@@ -661,10 +632,10 @@ class MMKernel {
                       C33);
         }
       } else {  // IsF32(): promote half-B to F32, F32*F32.
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
         const VF b00 = hn::PromoteLowerTo(df, b0);
         const VF b10 = hn::PromoteLowerTo(df, b1);
         const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -786,12 +757,9 @@ class MMPerPackage {
     if constexpr (WantDecompressA()) {
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
       DecompressA(A, A_view);
-      constexpr bool A_padded = true;  // MMStorage `pkg_A_` is padded.
-      DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, A_view, B, C_rows);
     } else {
-      const bool A_padded = HasPadding(A);
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
-                    C_rows);
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
     }
   }
 
@@ -808,42 +776,29 @@ class MMPerPackage {
     return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
   }
 
-  // Use instead of `MatPtr::IsPacked` because that returns true for single
-  // rows, but we want to know whether there is padding.
-  static bool HasPadding(const MatPtr& mat) {
-    return mat.Stride() > mat.Cols();
-  }
-
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A``
-  // and `B` are const, but StridedView is also used for non-const `partial`.
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
   template
   static StridedView View(const MatPtrT& AB, size_t r, size_t c,
                           size_t cols) {
     HWY_DASSERT(c < AB.Cols());
     HWY_DASSERT(cols <= AB.Cols() - c);
-    HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag());
-    (void)N;
-    // If `AB` is padded, then `LoopKC` expects the view is either a vector
-    // multiple, or all columns and thus also padded.
-    HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
     return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride());
   }
 
   // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
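// --- Illustrative sketch (editor's addition, not part of the commit) ---
// The essence of `View` above: a sub-view shares the parent's stride, so a
// narrower view has cols < stride even when the parent itself is packed,
// which is why the deleted HasPadding-based asserts were unreliable.
// Hypothetical, simplified types.
#include <cstddef>

template <typename T>
struct MiniView {
  T* row0;
  size_t cols;    // width of the view
  size_t stride;  // inherited from the parent matrix
  T* Row(size_t r) const { return row0 + r * stride; }
};

template <typename T>
MiniView<T> LeftColumns(T* parent, size_t stride, size_t cols) {
  // cols shrinks, stride does not: per row, elements in [cols, stride) are
  // the parent's remaining columns rather than zero padding.
  return MiniView<T>{parent, cols, stride};
}
// ------------------------------------------------------------------------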
   template
   HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
-                                const StridedView A, const bool A_padded,
-                                const MatPtrT& B,
+                                const StridedView A, const MatPtrT& B,
                                 RowPtrs C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_K(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT_K(parallel_policy, A, B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -852,8 +807,7 @@ class MMPerPackage {
   // Single M and K ranges, parallel N. Fills all of C directly.
   template
   HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView A,
-                       const bool A_padded, const MatPtrT& B,
-                       RowPtrs C_rows) const {
+                       const MatPtrT& B, RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -878,8 +832,8 @@ class MMPerPackage {
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
@@ -887,8 +841,7 @@ class MMPerPackage {
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
   template
   HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView A,
-                         const bool A_padded, const MatPtrT& B,
-                         RowPtrs C_rows) const {
+                         const MatPtrT& B, RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,8 +862,8 @@ class MMPerPackage {
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };
 
@@ -937,8 +890,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of C directly, in parallel.
   template
   HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView A,
-                          const bool A_padded, const MatPtrT& B,
-                          RowPtrs C_rows) const {
+                          const MatPtrT& B, RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
@@ -963,8 +915,8 @@ class MMPerPackage {
            row_b += kNR) {
         const StridedViewBF B_view =
             DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
 
@@ -973,8 +925,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
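// --- Illustrative sketch (editor's addition, not part of the commit) ---
// The loop structure shared by the DoNT* variants above: walk the N dimension
// in strips of kNR rows of transposed B (MatMul asserts N % kNR == 0) and
// hand each decompressed kNR x kc strip to the kernel. Names hypothetical.
#include <cstddef>

constexpr size_t kNR = 4;  // B rows (= C columns) per kernel invocation

void ForEachBStrip(size_t n_begin, size_t n_end,
                   void (*process_strip)(size_t row_b)) {
  for (size_t row_b = n_begin; row_b < n_end; row_b += kNR) {
    process_strip(row_b);  // DecompressB + MMKernel::A2C0 in the real code
  }
}
// ------------------------------------------------------------------------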
   template
   HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView A,
-                            const bool A_padded, const MatPtrT& B,
-                            RowPtrs C_rows) const {
+                            const MatPtrT& B, RowPtrs C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -995,8 +946,8 @@ class MMPerPackage {
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };  // loop_nc
     MMParallelPolicyT::ForRangesMC_NC(
@@ -1129,12 +1080,9 @@ class MMPerPackage {
     const hn::ScalableTag dbf;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
 
-    // View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
-    // and weights.cc zero-initialize the padding.
+    // Neither A nor B requires padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame()) {
-      if (B.Cols() % NBF == 0 || HasPadding(B)) {
-        return View(B, row_b, range_kc.begin(), range_kc.Num());
-      }
+      return View(B, row_b, range_kc.begin(), range_kc.Num());
     }
 
     const PackedSpan B_span = B.PaddedSpan();
@@ -1219,11 +1167,7 @@ struct MMImpl {
 // `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
-// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
-//
-// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
-// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match
-// the behavior of `DecompressAndZeroPad`. We check this in debug builds.
+// == 0` or `M <= 4`.
 //
 // If `add` is non-null, the row-vector `add` is added to each of the `M` rows
 // of `C`, which is a row-major matrix with arbitrary stride. A scale for
@@ -1282,6 +1226,14 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
   HWY_ASSERT(M <= MMStorage::kMaxM);
   HWY_ASSERT(K <= MMStorage::kMaxK);
   HWY_ASSERT(N % kNR == 0);
+  // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` is
+  // reliable: the latter returns true for single rows, and the former may
+  // match `Cols` if the width matches the padding.
+  // Note that B is packed in matmul_test, but otherwise generally padded.
+  HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes()));
+  if (A.Rows() > 1) {
+    HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
+  }
 
   tuner.SetCandidates(
       MMCandidates(allocator, M, K, N, MMPerPackage::ABytes(), sizeof(TC),
diff --git a/ops/matmul.h b/ops/matmul.h
index 16028f3..752bad1 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -194,6 +194,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 void BindC(ThreadingContext& ctx, MatPtr& C);
 
 // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
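// --- Illustrative sketch (editor's addition, not part of the commit) ---
// The reasoning behind the new alignment asserts in MatMul above: rows are a
// fixed stride apart, so if row 0 and row 1 are both line-aligned, the stride
// is a multiple of the line size and every later row is aligned too.
// Standalone, hypothetical form of the check:
#include <cstddef>
#include <cstdint>

bool RowsLineAligned(const void* row0, const void* row1, size_t line_bytes) {
  const uintptr_t a0 = reinterpret_cast<uintptr_t>(row0);
  const uintptr_t a1 = reinterpret_cast<uintptr_t>(row1);
  // Both aligned implies (a1 - a0), i.e. the stride, is also a multiple.
  return (a0 % line_bytes == 0) && (a1 % line_bytes == 0);
}
// ------------------------------------------------------------------------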
 #pragma pack(push, 1)  // power of two size
 template
 class StridedView {
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index aadbc56..14913a1 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -318,6 +318,7 @@ void TestAllMatMul() {
   ThreadingArgs threading_args;
   threading_args.bind = Tristate::kTrue;
+
   ThreadingContext ctx(threading_args);
   MatMulEnv env(ctx);
   NestedPools& pools = env.ctx.pools;
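// --- Illustrative sketch (editor's addition, not part of the commit) ---
// A minimal check of the remainder path, in the spirit of matmul_test's
// non-vector-multiple K cases: K = 7 is not a multiple of typical vector
// widths, forcing the LoadN tail. Reuses the hypothetical DotF32 from the
// sketch after the kc-loop hunk; all values are exact in float, so == holds
// regardless of summation order.
#include <cassert>
#include <cstddef>

float DotF32(const float* a, const float* b, size_t num);  // earlier sketch

int main() {
  float a[7], b[7];
  float expected = 0.0f;
  for (int i = 0; i < 7; ++i) {
    a[i] = 1.0f + i;  // 1..7
    b[i] = 0.5f;
    expected += a[i] * b[i];  // 0.5 * 28 = 14
  }
  assert(DotF32(a, b, 7) == expected);
  return 0;
}
// ------------------------------------------------------------------------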