mirror of https://github.com/google/gemma.cpp.git
Fix remainder handling for Paligemma
No longer attempt to skip the remainder handling because B might also be a non-padded view.

PiperOrigin-RevId: 800890805
parent 973e284ed6
commit 0ae8646731
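The core idea of the change, as a minimal standalone sketch (illustrative only, not the gemma.cpp kernel; the function name and arguments below are invented for this example): process full vectors with unaligned loads, then finish with `hn::LoadN`, so neither input needs zero-initialized padding beyond its logical width.

#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Dot product over `n` elements without requiring padded or aligned inputs.
float DotAnyLength(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
                   size_t n) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  auto sum = hn::Zero(d);
  size_t i = 0;
  if (n >= N) {
    // Main loop: full vectors; LoadU tolerates unaligned row starts.
    for (; i <= n - N; i += N) {
      sum = hn::MulAdd(hn::LoadU(d, a + i), hn::LoadU(d, b + i), sum);
    }
  }
  const size_t remaining = n - i;  // < N
  if (remaining != 0) {
    // LoadN reads only `remaining` lanes and zeroes the rest, so we never
    // touch bytes beyond the logical row width.
    sum = hn::MulAdd(hn::LoadN(d, a + i, remaining),
                     hn::LoadN(d, b + i, remaining), sum);
  }
  return hn::ReduceSum(d, sum);
}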
@@ -87,11 +87,6 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
     Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
              MakeSpan(compressed.Row(r), extents.cols),
              /*packed_ofs=*/0);
-
-    // MatMul requires that A's padding be zero-initialized.
-    hwy::ZeroBytes(
-        compressed.Row(r) + extents.cols,
-        (compressed.Stride() - extents.cols) * compressed.ElementBytes());
   });

   compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
@@ -120,11 +115,6 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
     Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
              MakeSpan(compressed.Row(r), extents.cols),
              /*packed_ofs=*/0);
-
-    // MatMul requires that B's padding be zero-initialized.
-    hwy::ZeroBytes(
-        compressed.Row(r) + extents.cols,
-        (compressed.Stride() - extents.cols) * compressed.ElementBytes());
   });

   // Arbitrary value, different from 1, must match `GenerateMat`.
@@ -444,9 +444,6 @@ static std::vector<IOBatch> MakeBatches(
       HWY_ASSERT(batches.back().Add(row_bytes, file_bytes_per_row));
     }
     offset += file_bytes_per_row;
-    // Must zero-initialize the in-memory row padding, see MatMul.
-    hwy::ZeroBytes(row_bytes + file_bytes_per_row,
-                   mem_stride_bytes - file_bytes_per_row);
     row_bytes += mem_stride_bytes;
   }
   HWY_ASSERT(offset == range.End());
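A hedged sketch of the row layout this hunk deals with (the copy loop below is illustrative, not the MakeBatches implementation; it only reuses variable names visible above): each packed file row of `file_bytes_per_row` bytes lands at a memory stride of `mem_stride_bytes`, and after this change the gap between them may stay uninitialized because MatMul no longer reads past the valid columns.

#include <cstddef>
#include <cstdint>
#include <cstring>

// Copy `num_rows` tightly packed file rows into memory rows with a larger
// stride. The trailing [file_bytes_per_row, mem_stride_bytes) bytes of each
// memory row are padding and are intentionally left untouched here.
void LayOutRows(const uint8_t* file_data, size_t num_rows,
                size_t file_bytes_per_row, uint8_t* row_bytes,
                size_t mem_stride_bytes) {
  for (size_t r = 0; r < num_rows; ++r) {
    std::memcpy(row_bytes + r * mem_stride_bytes,
                file_data + r * file_bytes_per_row, file_bytes_per_row);
  }
}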
ops/matmul-inl.h — 202 lines changed
@@ -216,7 +216,7 @@ class MMKernel {
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
   template <class Tag, typename TA, typename TC>
-  static HWY_INLINE void A2C0(const StridedView<TA> A_view, const bool A_padded,
+  static HWY_INLINE void A2C0(const StridedView<TA> A_view,
                               const StridedViewBF& B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
@@ -229,8 +229,8 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -239,13 +239,13 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                    args, C_rows);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -253,20 +253,18 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
-                  args, C_rows);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                C_rows);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -380,7 +378,6 @@ class MMKernel {
   // `B` is BF16, `A` and `C` can be F32 or BF16.
   template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
   static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
-                                const bool A_padded,
                                 const StridedViewBF& B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
                                 const MMArgs& args, RowPtrs<TC> C_rows) {
@@ -405,33 +402,8 @@ class MMKernel {
     const BF16* HWY_RESTRICT br2 = B_view.Row(2);
     const BF16* HWY_RESTRICT br3 = B_view.Row(3);

-    // Ensure `A` and `B` were zero-padded.
-    if constexpr (HWY_IS_DEBUG_BUILD) {
-      // Only check if `A` is padded, i.e. not packed.
-      if (A_padded) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
-          {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 1) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 2) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
-          }
-          if constexpr (kRowsAC > 3) {
-            HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
-          }
-        }
-      }
-      // B is unconditionally zero-padded by `DecompressAndZeroPad`.
-      for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br0[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br1[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br2[i]) == 0.0f);
-        HWY_DASSERT(hwy::ConvertScalarTo<float>(br3[i]) == 0.0f);
-      }
-    }
+    // Neither A nor B are guaranteed to be zero-padded: they might be a view
+    // into the left half.

     // Accumulate into f32.
     const hn::Repartition<float, decltype(dbf)> df;
@@ -447,18 +419,18 @@ class MMKernel {
     // The loop step is always NBF: for non-native BF16 with TA=F32, this
     // entails 2x unrolling, which helps a little.
     const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
-    // If A is packed (not padded), we have to check for remainders. Otherwise,
-    // we only run the main loop because A's padding is zero-initialized by
-    // `ZeroInit` or weights.cc.
-    const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
-    if (kc_end >= kc_step) {
+    if (kc >= kc_step) {
       HWY_UNROLL(1)
-      for (; ikc <= kc_end - kc_step; ikc += kc_step) {
+      for (; ikc <= kc - kc_step; ikc += kc_step) {
         if constexpr (HWY_NATIVE_DOT_BF16) {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          // NOTE: matmul_test has packed B so that it can call Span. The test
+          // cases with non-vector-multiple K require unaligned loads here.
+          // However, in actual usage, we should always have padded and thus
+          // aligned A and B.
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);

           // Should only get here if `A` is BF16, otherwise `DecompressA` would
           // convert to BF16 and `A_view` points to that.
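Briefly, on the Load/LoadU distinction in the hunk above (a standalone illustration using Highway's public API; the function below is invented for this example): hn::Load requires the address to be aligned to the vector size, which holds for row starts of padded, aligned matrices but not for arbitrary column offsets into a packed one, whereas hn::LoadU accepts any address.

#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

// Sums Lanes(d) floats starting at an arbitrary element offset. hn::Load
// would require `p + offset` to be vector-aligned; hn::LoadU does not, which
// is what packed (non-padded) rows with arbitrary column offsets need.
float SumOneVectorAt(const float* p, size_t offset) {
  const hn::ScalableTag<float> d;
  const auto v = hn::LoadU(d, p + offset);  // no alignment requirement
  return hn::ReduceSum(d, v);
}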
@@ -491,10 +463,10 @@ class MMKernel {
           // shuffles.
           VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
           {
-            const VBF b0 = hn::Load(dbf, br0 + ikc);
-            const VBF b1 = hn::Load(dbf, br1 + ikc);
-            const VBF b2 = hn::Load(dbf, br2 + ikc);
-            const VBF b3 = hn::Load(dbf, br3 + ikc);
+            const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+            const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+            const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+            const VBF b3 = hn::LoadU(dbf, br3 + ikc);
             b0e = hn::PromoteEvenTo(df, b0);
             b1e = hn::PromoteEvenTo(df, b1);
             b2e = hn::PromoteEvenTo(df, b2);
@@ -523,10 +495,10 @@ class MMKernel {
           }
         } else {  // IsF32<TA>(): promote BF to 2xF32, F32*F32.
           // Full-vector loads are a bit faster on SKX than half + PromoteTo.
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
           const VF b00 = hn::PromoteLowerTo(df, b0);
           const VF b10 = hn::PromoteLowerTo(df, b1);
           const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -586,17 +558,16 @@ class MMKernel {
       }
     }

-    // We want the number of actual valid kc, but we may already be beyond `kc`.
-    const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
+    // Always handle remainders: even though A and B are generally padded, we
+    // might have a view into the left half of A and/or B.
+    const size_t remaining_kc = kc - ikc;
     HWY_DASSERT(remaining_kc < kc_step);
-    HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
-    // Last iteration: B is padded but A is not; guard its loads.
     if (HWY_UNLIKELY(remaining_kc != 0)) {
       if constexpr (HWY_NATIVE_DOT_BF16) {
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);

         // Should only get here if `A` is BF16, otherwise `DecompressA` would
         // convert to BF16 and `A_view` points to that.
@@ -628,10 +599,10 @@ class MMKernel {
         // lane-crossing promotion for both might be bottlenecked on shuffles.
         VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
         {
-          const VBF b0 = hn::Load(dbf, br0 + ikc);
-          const VBF b1 = hn::Load(dbf, br1 + ikc);
-          const VBF b2 = hn::Load(dbf, br2 + ikc);
-          const VBF b3 = hn::Load(dbf, br3 + ikc);
+          const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+          const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+          const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+          const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
           b0e = hn::PromoteEvenTo(df, b0);
           b1e = hn::PromoteEvenTo(df, b1);
           b2e = hn::PromoteEvenTo(df, b2);
@@ -661,10 +632,10 @@ class MMKernel {
                                         C33);
         }
       } else {  // IsF32<TA>(): promote half-B to F32, F32*F32.
-        const VBF b0 = hn::Load(dbf, br0 + ikc);
-        const VBF b1 = hn::Load(dbf, br1 + ikc);
-        const VBF b2 = hn::Load(dbf, br2 + ikc);
-        const VBF b3 = hn::Load(dbf, br3 + ikc);
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
         const VF b00 = hn::PromoteLowerTo(df, b0);
         const VF b10 = hn::PromoteLowerTo(df, b1);
         const VF b20 = hn::PromoteLowerTo(df, b2);
@@ -786,12 +757,9 @@ class MMPerPackage {
     if constexpr (WantDecompressA<TA>()) {
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
      DecompressA<MMParallelPolicyT>(A, A_view);
-      constexpr bool A_padded = true;  // MMStorage `pkg_A_` is padded.
-      DispatchOrder(parallel_policy, A_view, A_padded, B, C_rows);
+      DispatchOrder(parallel_policy, A_view, B, C_rows);
     } else {
-      const bool A_padded = HasPadding(A);
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), A_padded, B,
-                    C_rows);
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
     }
   }

@@ -808,42 +776,29 @@ class MMPerPackage {
     return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
   }

-  // Use instead of `MatPtr::IsPacked` because that returns true for single
-  // rows, but we want to know whether there is padding.
-  static bool HasPadding(const MatPtr& mat) {
-    return mat.Stride() > mat.Cols();
-  }
-
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`
-  // and `B` are const, but StridedView is also used for non-const `partial`.
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
   template <typename T>
   static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
                              size_t cols) {
     HWY_DASSERT(c < AB.Cols());
     HWY_DASSERT(cols <= AB.Cols() - c);
-    HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag<T>());
-    (void)N;
-    // If `AB` is padded, then `LoopKC` expects the view is either a vector
-    // multiple, or all columns and thus also padded.
-    HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
     return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
   }

   // `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DispatchOrder(const MMParallelPolicyT& parallel_policy,
-                                const StridedView<TA> A, const bool A_padded,
-                                const MatPtrT<TB>& B,
+                                const StridedView<TA> A, const MatPtrT<TB>& B,
                                 RowPtrs<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT<TA, TB, TC>(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT<TA, TB, TC>(parallel_policy, A, B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, A_padded, B, C_rows);
+        return DoNT_MT_K<TA, TB, TC>(parallel_policy, A, B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -852,8 +807,7 @@ class MMPerPackage {
   // Single M and K ranges, parallel N. Fills all of C directly.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT(MMParallelPolicyT, const StridedView<TA> A,
-                       const bool A_padded, const MatPtrT<TB>& B,
-                       RowPtrs<TC> C_rows) const {
+                       const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -878,8 +832,8 @@ class MMPerPackage {
            row_b += kNR) {
         StridedViewBF B_view =
             DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
@@ -887,8 +841,7 @@ class MMPerPackage {
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_K(MMParallelPolicyT, const StridedView<TA> A,
-                         const bool A_padded, const MatPtrT<TB>& B,
-                         RowPtrs<TC> C_rows) const {
+                         const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
     const IndexRange& range_mc = ranges_mc_.Range(0);
@@ -909,8 +862,8 @@ class MMPerPackage {
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };

@@ -937,8 +890,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_MT(MMParallelPolicyT, const StridedView<TA> A,
-                          const bool A_padded, const MatPtrT<TB>& B,
-                          RowPtrs<TC> C_rows) const {
+                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
     const IndexRange& range_K = ranges_kc_.Range(0);
@@ -963,8 +915,8 @@ class MMPerPackage {
            row_b += kNR) {
         const StridedViewBF B_view =
             DecompressB(B, row_b, range_K, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
-                       MMSetC(), args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
+                       args_, C_rows);
       }
     });
   }
@@ -973,8 +925,7 @@ class MMPerPackage {
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TA, typename TB, typename TC, class MMParallelPolicyT>
   HWY_INLINE void DoNT_MT_K(MMParallelPolicyT, const StridedView<TA> A,
-                            const bool A_padded, const MatPtrT<TB>& B,
-                            RowPtrs<TC> C_rows) const {
+                            const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
     static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
     const size_t kc_max = ranges_kc_.TaskSize();
     HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
@@ -995,8 +946,8 @@ class MMPerPackage {
       for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
            row_b += kNR) {
         StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
-                       out_tag, args_, C_rows);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C_rows);
       }
     };  // loop_nc
     MMParallelPolicyT::ForRangesMC_NC(
@@ -1129,12 +1080,9 @@ class MMPerPackage {
     const hn::ScalableTag<BF16> dbf;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);

-    // View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
-    // and weights.cc zero-initialize the padding.
+    // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      if (B.Cols() % NBF == 0 || HasPadding(B)) {
-        return View(B, row_b, range_kc.begin(), range_kc.Num());
-      }
+      return View(B, row_b, range_kc.begin(), range_kc.Num());
     }

     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -1219,11 +1167,7 @@ struct MMImpl {
 // `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
-// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
-//
-// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
-// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match
-// the behavior of `DecompressAndZeroPad`. We check this in debug builds.
+// == 0` or `M <= 4`.
 //
 // If `add` is non-null, the row-vector `add` is added to each of the `M` rows
 // of `C`, which is a row-major matrix with arbitrary stride. A scale for
@@ -1282,6 +1226,14 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   HWY_ASSERT(M <= MMStorage::kMaxM);
   HWY_ASSERT(K <= MMStorage::kMaxK);
   HWY_ASSERT(N % kNR == 0);
+  // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
+  // reliable: the latter returns true for single rows, and the former may
+  // match `Cols` if the width matches the padding.
+  // Note that B is packed in matmul_test, but otherwise generally padded.
+  HWY_ASSERT(hwy::IsAligned(A.Row(0), env.ctx.allocator.LineBytes()));
+  if (A.Rows() > 1) {
+    HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
+  }

   tuner.SetCandidates(
       MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
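A hedged sketch of the alignment check added above (the helper below is hypothetical, not part of gemma.cpp; it only reuses the `Row`, `Rows`, and line-size notions visible in the hunk): verifying that the first two row pointers are line-aligned also implies the row stride in bytes is a multiple of the line size.

#include <cstddef>
#include "hwy/base.h"  // hwy::IsAligned

// Returns true if the matrix's rows start at line-aligned addresses.
// `Mat` is any type with Row(size_t) and Rows() accessors (an assumption).
template <class Mat>
bool RowsLineAligned(const Mat& m, size_t line_bytes) {
  if (!hwy::IsAligned(m.Row(0), line_bytes)) return false;
  // With row 0 aligned, checking row 1 is equivalent to checking that the
  // stride between consecutive rows is a multiple of `line_bytes`.
  return m.Rows() <= 1 || hwy::IsAligned(m.Row(1), line_bytes);
}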
@@ -194,6 +194,7 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 void BindC(ThreadingContext& ctx, MatPtr& C);

 // Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
 #pragma pack(push, 1)  // power of two size
 template <typename T>
 class StridedView {
@@ -318,6 +318,7 @@ void TestAllMatMul() {

   ThreadingArgs threading_args;
+  threading_args.bind = Tristate::kTrue;

   ThreadingContext ctx(threading_args);
   MatMulEnv env(ctx);
   NestedPools& pools = env.ctx.pools;