Fix remainder handling for Paligemma

No longer attempt to skip the remainder handling because B might also be a non-padded view.

PiperOrigin-RevId: 800890805
This commit is contained in:
Jan Wassenberg 2025-08-29 07:25:14 -07:00 committed by Copybara-Service
parent 973e284ed6
commit 0ae8646731
5 changed files with 79 additions and 138 deletions

View File

@ -87,11 +87,6 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
// MatMul requires that A's padding be zero-initialized.
hwy::ZeroBytes(
compressed.Row(r) + extents.cols,
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
});
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
@ -120,11 +115,6 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
// MatMul requires that B's padding be zero-initialized.
hwy::ZeroBytes(
compressed.Row(r) + extents.cols,
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
});
// Arbitrary value, different from 1, must match `GenerateMat`.

View File

@ -444,9 +444,6 @@ static std::vector<IOBatch> MakeBatches(
HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
}
offset += file_bytes_per_row;
// Must zero-initialize the in-memory row padding, see MatMul.
hwy::ZeroBytes(row_bytes + file_bytes_per_row,
mem_stride_bytes - file_bytes_per_row);
row_bytes += mem_stride_bytes;
}
HWY_ASSERT(offset == range.End());

View File

@ -216,7 +216,7 @@ class MMKernel {
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
template <class Tag, typename TA, typename TC>
static HWY_INLINE void A2C0(const StridedView<TA> A_view, const bool A_padded,
static HWY_INLINE void A2C0(const StridedView<TA> A_view,
const StridedViewBF& B_view, size_t mr,
const IndexRange& range_mc, const size_t row_b,
size_t kc, Tag tag, const MMArgs& args,
@ -229,8 +229,8 @@ class MMKernel {
// M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) {
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
}
return;
}
@ -239,13 +239,13 @@ class MMKernel {
if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) {
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
}
}
if (HWY_UNLIKELY(imc != mc)) {
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
}
return;
}
@ -253,20 +253,18 @@ class MMKernel {
HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) {
LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
}
}
const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
imc += 2;
}
if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
imc += 1;
}
HWY_DASSERT(imc == mc);
@ -380,7 +378,6 @@ class MMKernel {
// `B` is BF16, `A` and `C` can be F32 or BF16.
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
const bool A_padded,
const StridedViewBF& B_view, size_t row_ac,
size_t imc, size_t col_c, size_t kc, Tag tag,
const MMArgs& args, RowPtrs<TC> C_rows) {
@ -405,33 +402,8 @@ class MMKernel {
const BF16* HWY_RESTRICT br2 = B_view.Row(2);
const BF16* HWY_RESTRICT br3 = B_view.Row(3);
// Ensure `A` and `B` were zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
// Only check if `A` is padded, i.e. not packed.
if (A_padded) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
{
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
}
if constexpr (kRowsAC > 1) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
}
if constexpr (kRowsAC > 2) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
}
if constexpr (kRowsAC > 3) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
}
}
}
// B is unconditionally zero-padded by `DecompressAndZeroPad`.
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(br0[i]) == 0.0f);
HWY_DASSERT(hwy::ConvertScalarTo<float>(br1[i]) == 0.0f);
HWY_DASSERT(hwy::ConvertScalarTo<float>(br2[i]) == 0.0f);
HWY_DASSERT(hwy::ConvertScalarTo<float>(br3[i]) == 0.0f);
}
}
// Neither A nor B are guaranteed to be zero-padded: they might be a view
// into the left half.
// Accumulate into f32.
const hn::Repartition<float, decltype(dbf)> df;
@ -447,18 +419,18 @@ class MMKernel {
// The loop step is always NBF: for non-native BF16 with TA=F32, this
// entails 2x unrolling, which helps a little.
const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
// If A is packed (not padded), we have to check for remainders. Otherwise,
// we only run the main loop because A's padding is zero-initialized by
// `ZeroInit` or weights.cc.
const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
if (kc_end >= kc_step) {
if (kc >= kc_step) {
HWY_UNROLL(1)
for (; ikc <= kc_end - kc_step; ikc += kc_step) {
for (; ikc <= kc - kc_step; ikc += kc_step) {
if constexpr (HWY_NATIVE_DOT_BF16) {
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
// NOTE: matmul_test has packed B so that it can call Span. The test
// cases with non-vector-multiple K require unaligned loads here.
// However, in actual usage, we should always have padded and thus
// aligned A and B.
const VBF b0 = hn::LoadU(dbf, br0 + ikc);
const VBF b1 = hn::LoadU(dbf, br1 + ikc);
const VBF b2 = hn::LoadU(dbf, br2 + ikc);
const VBF b3 = hn::LoadU(dbf, br3 + ikc);
// Should only get here if `A` is BF16, otherwise `DecompressA` would
// convert to BF16 and `A_view` points to that.
@ -491,10 +463,10 @@ class MMKernel {
// shuffles.
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
{
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VBF b0 = hn::LoadU(dbf, br0 + ikc);
const VBF b1 = hn::LoadU(dbf, br1 + ikc);
const VBF b2 = hn::LoadU(dbf, br2 + ikc);
const VBF b3 = hn::LoadU(dbf, br3 + ikc);
b0e = hn::PromoteEvenTo(df, b0);
b1e = hn::PromoteEvenTo(df, b1);
b2e = hn::PromoteEvenTo(df, b2);
@ -523,10 +495,10 @@ class MMKernel {
}
} else { // IsF32<TA>(): promote BF to 2xF32, F32*F32.
// Full-vector loads are a bit faster on SKX than half + PromoteTo.
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VBF b0 = hn::LoadU(dbf, br0 + ikc);
const VBF b1 = hn::LoadU(dbf, br1 + ikc);
const VBF b2 = hn::LoadU(dbf, br2 + ikc);
const VBF b3 = hn::LoadU(dbf, br3 + ikc);
const VF b00 = hn::PromoteLowerTo(df, b0);
const VF b10 = hn::PromoteLowerTo(df, b1);
const VF b20 = hn::PromoteLowerTo(df, b2);
@ -586,17 +558,16 @@ class MMKernel {
}
}
// We want the number of actual valid kc, but we may already be beyond `kc`.
const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
// Always handle remainders: even though A and B are generally padded, we
// might have a view into the left half of A and/or B.
const size_t remaining_kc = kc - ikc;
HWY_DASSERT(remaining_kc < kc_step);
HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
// Last iteration: B is padded but A is not; guard its loads.
if (HWY_UNLIKELY(remaining_kc != 0)) {
if constexpr (HWY_NATIVE_DOT_BF16) {
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
// Should only get here if `A` is BF16, otherwise `DecompressA` would
// convert to BF16 and `A_view` points to that.
@ -628,10 +599,10 @@ class MMKernel {
// lane-crossing promotion for both might be bottlenecked on shuffles.
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
{
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
b0e = hn::PromoteEvenTo(df, b0);
b1e = hn::PromoteEvenTo(df, b1);
b2e = hn::PromoteEvenTo(df, b2);
@ -661,10 +632,10 @@ class MMKernel {
C33);
}
} else { // IsF32<TA>(): promote half-B to F32, F32*F32.
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
const VF b00 = hn::PromoteLowerTo(df, b0);
const VF b10 = hn::PromoteLowerTo(df, b1);
const VF b20 = hn::PromoteLowerTo(df, b2);
@ -786,12 +757,9 @@ class MMPerPackage {
if constexpr (WantDecompressA<TA>()) {
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
DecompressA<MMParallelPolicyT>(A, A_view);
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
DispatchOrder(parallel_policy, A_view, B, C_rows);
} else {
const bool A_padded = HasPadding(A);
DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
C_rows);
DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
}
}
@ -808,42 +776,29 @@ class MMPerPackage {
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
}
// Use instead of `MatPtr::IsPacked` because that returns true for single
// rows, but we want to know whether there is padding.
static bool HasPadding(const MatPtr& mat) {
return mat.Stride() > mat.Cols();
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A``
// and `B` are const, but StridedView is also used for non-const `partial`.
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag<T>());
(void)N;
// If `AB` is padded, then `LoopKC` expects the view is either a vector
// multiple, or all columns and thus also padded.
HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B,
const StridedView<TA> A, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
switch (order_) {
case MMOrder::kNT:
return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
return DoNT<TA, TB, TC>(parallel_policy, A, B, C_rows);
case MMOrder::kNT_K:
return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
return DoNT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
case MMOrder::kNT_MT:
return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
return DoNT_MT<TA, TB, TC>(parallel_policy, A, B, C_rows);
case MMOrder::kNT_MT_K:
return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
default:
HWY_UNREACHABLE;
}
@ -852,8 +807,7 @@ class MMPerPackage {
// Single M and K ranges, parallel N. Fills all of C directly.
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@ -878,8 +832,8 @@ class MMPerPackage {
row_b += kNR) {
StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
MMSetC(), args_, C_rows);
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_, C_rows);
}
});
}
@ -887,8 +841,7 @@ class MMPerPackage {
// Single M range, parallel N, sequential K. Sets C, then accumulates.
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
const IndexRange& range_mc = ranges_mc_.Range(0);
@ -909,8 +862,8 @@ class MMPerPackage {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
out_tag, args_, C_rows);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
};
@ -937,8 +890,7 @@ class MMPerPackage {
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0);
@ -963,8 +915,8 @@ class MMPerPackage {
row_b += kNR) {
const StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
MMSetC(), args_, C_rows);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_, C_rows);
}
});
}
@ -973,8 +925,7 @@ class MMPerPackage {
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
const bool A_padded, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@ -995,8 +946,8 @@ class MMPerPackage {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
out_tag, args_, C_rows);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
}
}; // loop_nc
MMParallelPolicyT::ForRangesMC_NC(
@ -1129,13 +1080,10 @@ class MMPerPackage {
const hn::ScalableTag<BF16> dbf;
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
// View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
// and weights.cc zero-initialize the padding.
// Neither A nor B require padding because `LoopKC` handles remainders.
if constexpr (hwy::IsSame<TB, BF16>()) {
if (B.Cols() % NBF == 0 || HasPadding(B)) {
return View(B, row_b, range_kc.begin(), range_kc.Num());
}
}
const PackedSpan<const TB> B_span = B.PaddedSpan();
@ -1219,11 +1167,7 @@ struct MMImpl {
// `K = B.Cols()`, which must match `A.Cols()`, is the number
// of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
// are no other restrictions on shape, though performance is better when `M % 4
// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
//
// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match
// the behavior of `DecompressAndZeroPad`. We check this in debug builds.
// == 0` or `M <= 4`.
//
// If `add` is non-null, the row-vector `add` is added to each of the `M` rows
// of `C`, which is a row-major matrix with arbitrary stride. A scale for
@ -1282,6 +1226,14 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_ASSERT(M <= MMStorage::kMaxM);
HWY_ASSERT(K <= MMStorage::kMaxK);
HWY_ASSERT(N % kNR == 0);
// Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
// reliable: the latter returns true for single rows, and the former may
// match `Cols` if the width matches the padding.
// Note that B is packed in matmul_test, but otherwise generally padded.
HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes()));
if (A.Rows() > 1) {
HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
}
tuner.SetCandidates(
MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),

View File

@ -194,6 +194,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
void BindC(ThreadingContext& ctx, MatPtr& C);
// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
// Also used to decompress B, hence non-const.
#pragma pack(push, 1) // power of two size
template <typename T>
class StridedView {

View File

@ -318,6 +318,7 @@ void TestAllMatMul() {
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
ThreadingContext ctx(threading_args);
MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools;