Fix remainder handling for Paligemma

No longer attempt to skip the remainder handling because B might also be a non-padded view.

PiperOrigin-RevId: 800890805
Jan Wassenberg 2025-08-29 07:25:14 -07:00 committed by Copybara-Service
parent 973e284ed6
commit 0ae8646731
5 changed files with 79 additions and 138 deletions
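For context, the change can be illustrated outside the kernel. The following is a minimal, hypothetical sketch (not the actual `LoopKC` code): a single-row float dot product instead of the kernel's 4x4 BF16 tile, compiled for Highway's static target. It shows the pattern this commit adopts: full (possibly unaligned) vector loads in the main loop, and `hn::LoadN` for the tail, so neither input row needs zero-initialized padding.

#include <cstddef>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Hypothetical helper, not part of the repository: accumulates
// sum(a[i] * b[i]) for i in [0, kc) without requiring either row to be
// zero-padded to a vector multiple.
float DotRemainderSafe(const float* HWY_RESTRICT a,
                       const float* HWY_RESTRICT b, size_t kc) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  auto sum = hn::Zero(d);
  size_t i = 0;
  // Main loop: whole vectors; LoadU because a view need not be aligned.
  if (kc >= N) {
    for (; i <= kc - N; i += N) {
      sum = hn::MulAdd(hn::LoadU(d, a + i), hn::LoadU(d, b + i), sum);
    }
  }
  // Remainder: LoadN reads only the valid lanes and zeros the rest, so there
  // is no out-of-bounds access and no reliance on zero-initialized padding.
  const size_t remaining = kc - i;
  if (remaining != 0) {
    sum = hn::MulAdd(hn::LoadN(d, a + i, remaining),
                     hn::LoadN(d, b + i, remaining), sum);
  }
  return hn::ReduceSum(d, sum);
}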

View File

@@ -87,11 +87,6 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
         Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
                  MakeSpan(compressed.Row(r), extents.cols),
                  /*packed_ofs=*/0);
-        // MatMul requires that A's padding be zero-initialized.
-        hwy::ZeroBytes(
-            compressed.Row(r) + extents.cols,
-            (compressed.Stride() - extents.cols) * compressed.ElementBytes());
       });
   compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
@@ -120,11 +115,6 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
         Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
                  MakeSpan(compressed.Row(r), extents.cols),
                  /*packed_ofs=*/0);
-        // MatMul requires that B's padding be zero-initialized.
-        hwy::ZeroBytes(
-            compressed.Row(r) + extents.cols,
-            (compressed.Stride() - extents.cols) * compressed.ElementBytes());
       });
   // Arbitrary value, different from 1, must match `GenerateMat`.

View File

@@ -444,9 +444,6 @@ static std::vector<IOBatch> MakeBatches(
       HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
     }
     offset += file_bytes_per_row;
-    // Must zero-initialize the in-memory row padding, see MatMul.
-    hwy::ZeroBytes(row_bytes + file_bytes_per_row,
-                   mem_stride_bytes - file_bytes_per_row);
     row_bytes += mem_stride_bytes;
   }
   HWY_ASSERT(offset == range.End());

View File

@@ -216,7 +216,7 @@ class MMKernel {
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
   template <class Tag, typename TA, typename TC>
-  static HWY_INLINE void A2C0(const StridedView<TA> A_view, const bool A_padded,
+  static HWY_INLINE void A2C0(const StridedView<TA> A_view,
                               const StridedViewBF& B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
@@ -229,8 +229,8 @@
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -239,13 +239,13 @@
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                    args, C_rows);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -253,20 +253,18 @@
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
      imc += 1;
    }
    HWY_DASSERT(imc == mc);
@@ -380,7 +378,6 @@
   // `B` is BF16, `A` and `C` can be F32 or BF16.
   template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
   static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
-                                const bool A_padded,
                                 const StridedViewBF& B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
                                 const MMArgs& args, RowPtrs<TC> C_rows) {
@@ -405,33 +402,8 @@
     const BF16* HWY_RESTRICT br2 = B_view.Row(2);
     const BF16* HWY_RESTRICT br3 = B_view.Row(3);
-    // Ensure `A` and `B` were zero-padded.
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      // Only check if `A` is padded, i.e. not packed.
-      if (A_padded) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
-          {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 1) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 2) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 3) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
-          }
-        }
-      }
-      // B is unconditionally zero-padded by `DecompressAndZeroPad`.
-      for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br0[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br1[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br2[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br3[i]) == 0.0f);
-      }
-    }
+    // Neither A nor B are guaranteed to be zero-padded: they might be a view
+    // into the left half.
     // Accumulate into f32.
     const hn::Repartition<float, decltype(dbf)> df;
@@ -447,18 +419,18 @@
     // The loop step is always NBF: for non-native BF16 with TA=F32, this
     // entails 2x unrolling, which helps a little.
     const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
-    // If A is packed (not padded), we have to check for remainders. Otherwise,
-    // we only run the main loop because A's padding is zero-initialized by
-    // `ZeroInit` or weights.cc.
-    const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
-    if (kc_end >= kc_step) {
+    if (kc >= kc_step) {
       HWY_UNROLL(1)
-      for (; ikc <= kc_end - kc_step; ikc += kc_step) {
+      for (; ikc <= kc - kc_step; ikc += kc_step) {
         if constexpr (HWY_NATIVE_DOT_BF16) {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          // NOTE: matmul_test has packed B so that it can call Span. The test
+          // cases with non-vector-multiple K require unaligned loads here.
+          // However, in actual usage, we should always have padded and thus
+          // aligned A and B.
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
           // Should only get here if `A` is BF16, otherwise `DecompressA` would
           // convert to BF16 and `A_view` points to that.
@@ -491,10 +463,10 @@
           // shuffles.
           VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
           {
-            const VBF b0 = hn::Load(dbf, br0 + ikc);
-            const VBF b1 = hn::Load(dbf, br1 + ikc);
-            const VBF b2 = hn::Load(dbf, br2 + ikc);
-            const VBF b3 = hn::Load(dbf, br3 + ikc);
+            const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+            const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+            const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+            const VBF b3 = hn::LoadU(dbf, br3 + ikc);
             b0e = hn::PromoteEvenTo(df, b0);
             b1e = hn::PromoteEvenTo(df, b1);
             b2e = hn::PromoteEvenTo(df, b2);
@@ -523,10 +495,10 @@
           }
         } else {  // IsF32<TA>(): promote BF to 2xF32, F32*F32.
           // Full-vector loads are a bit faster on SKX than half + PromoteTo.
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
           const VF b00 = hn::PromoteLowerTo(df, b0);
           const VF b10 = hn::PromoteLowerTo(df, b1);
           const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -586,17 +558,16 @@
       }
     }
-    // We want the number of actual valid kc, but we may already be beyond `kc`.
-    const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
+    // Always handle remainders: even though A and B are generally padded, we
+    // might have a view into the left half of A and/or B.
+    const size_t remaining_kc = kc - ikc;
     HWY_DASSERT(remaining_kc < kc_step);
-    HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
-    // Last iteration: B is padded but A is not; guard its loads.
     if (HWY_UNLIKELY(remaining_kc != 0)) {
       if constexpr (HWY_NATIVE_DOT_BF16) {
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
         // Should only get here if `A` is BF16, otherwise `DecompressA` would
         // convert to BF16 and `A_view` points to that.
@@ -628,10 +599,10 @@
         // lane-crossing promotion for both might be bottlenecked on shuffles.
         VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
         {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+          const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+          const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+          const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
           b0e = hn::PromoteEvenTo(df, b0);
           b1e = hn::PromoteEvenTo(df, b1);
           b2e = hn::PromoteEvenTo(df, b2);
@@ -661,10 +632,10 @@
                     C33);
         }
       } else {  // IsF32<TA>(): promote half-B to F32, F32*F32.
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
         const VF b00 = hn::PromoteLowerTo(df, b0);
         const VF b10 = hn::PromoteLowerTo(df, b1);
         const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -786,12 +757,9 @@ class MMPerPackage {
     if constexpr (WantDecompressA<TA>()) {
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
       DecompressA<MMParallelPolicyT>(A, A_view);
-      constexpr bool A_padded = true;  // MMStorage `pkg_A_` is padded.
-      DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, A_view, B, C_rows);
     } else {
-      const bool A_padded = HasPadding(A);
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
-                    C_rows);
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
     }
   }
@@ -808,42 +776,29 @@
     return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
   }
-  // Use instead of `MatPtr::IsPacked` because that returns true for single
-  // rows, but we want to know whether there is padding.
-  static bool HasPadding(const MatPtr& mat) {
-    return mat.Stride() > mat.Cols();
-  }
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`
-  // and `B` are const, but StridedView is also used for non-const `partial`.
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
   template <typename T>
   static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
                              size_t cols) {
     HWY_DASSERT(c < AB.Cols());
     HWY_DASSERT(cols <= AB.Cols() - c);
-    HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag<T>());
-    (void)N;
-    // If `AB` is padded, then `LoopKC` expects the view is either a vector
-    // multiple, or all columns and thus also padded.
-    HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
     return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
   }
   // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
-                                const StridedView<TA> A, const bool A_padded,
-                                const MatPtrT<TB>& B,
+                                const StridedView<TA> A, const MatPtrT<TB>& B,
                                 RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT<TA, TB, TC>(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
      case MMOrder::kNT_MT:
-        return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT<TA, TB, TC>(parallel_policy, A, B, C_rows);
      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
      default:
        HWY_UNREACHABLE;
    }
@@ -852,8 +807,7 @@
   // Single M and K ranges, parallel N. Fills all of C directly.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
-                       const bool A_padded, const MatPtrT<TB>& B,
-                       RowPtrs<TC> C_rows) const {
+                       const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -878,8 +832,8 @@
            row_b += kNR) {
         StridedViewBF B_view =
             DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
@@ -887,8 +841,7 @@
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
-                         const bool A_padded, const MatPtrT<TB>& B,
-                         RowPtrs<TC> C_rows) const {
+                         const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,8 +862,8 @@
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };
@@ -937,8 +890,7 @@
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
-                          const bool A_padded, const MatPtrT<TB>& B,
-                          RowPtrs<TC> C_rows) const {
+                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
@@ -963,8 +915,8 @@
            row_b += kNR) {
         const StridedViewBF B_view =
             DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
@@ -973,8 +925,7 @@
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
-                            const bool A_padded, const MatPtrT<TB>& B,
-                            RowPtrs<TC> C_rows) const {
+                            const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -995,8 +946,8 @@
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };  // loop_nc
     MMParallelPolicyT::ForRangesMC_NC(
@@ -1129,13 +1080,10 @@
     const hn::ScalableTag<BF16> dbf;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
-    // View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
-    // and weights.cc zero-initialize the padding.
+    // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      if (B.Cols() % NBF == 0 || HasPadding(B)) {
       return View(B, row_b, range_kc.begin(), range_kc.Num());
     }
-    }
     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -1219,11 +1167,7 @@ struct MMImpl {
 // `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
-// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
-//
-// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
-// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match
-// the behavior of `DecompressAndZeroPad`. We check this in debug builds.
+// == 0` or `M <= 4`.
 //
 // If `add` is non-null, the row-vector `add` is added to each of the `M` rows
 // of `C`, which is a row-major matrix with arbitrary stride. A scale for
@@ -1282,6 +1226,14 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   HWY_ASSERT(M <= MMStorage::kMaxM);
   HWY_ASSERT(K <= MMStorage::kMaxK);
   HWY_ASSERT(N % kNR == 0);
+  // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
+  // reliable: the latter returns true for single rows, and the former may
+  // match `Cols` if the width matches the padding.
+  // Note that B is packed in matmul_test, but otherwise generally padded.
+  HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes()));
+  if (A.Rows() > 1) {
+    HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
+  }
   tuner.SetCandidates(
       MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),

View File

@@ -194,6 +194,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 void BindC(ThreadingContext& ctx, MatPtr& C);
 // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
 #pragma pack(push, 1)  // power of two size
 template <typename T>
 class StridedView {

View File

@@ -318,6 +318,7 @@ void TestAllMatMul() {
   ThreadingArgs threading_args;
   threading_args.bind = Tristate::kTrue;
   ThreadingContext ctx(threading_args);
   MatMulEnv env(ctx);
   NestedPools& pools = env.ctx.pools;