f32 LoopKC: 1.37x (M=512), 1.19x (M=128) single-K F32,BF16 matmul speedup on SKX

Add a special case for A=F32,B=BF16, used when there is no native bf16 dot product.

dot-inl: ensure bf16,f32 and f32,bf16 both get promoted to float before f64 summation
matmul.cc: update autotuning to reflect actual A size
matmul_test: add all combinations of bf16/f32; report all results rather than only the first difference; check non-vector-aligned K
PiperOrigin-RevId: 800487817
Jan Wassenberg 2025-08-28 08:55:15 -07:00 committed by Copybara-Service
parent 98ddc166db
commit 31c09cca4c
6 changed files with 594 additions and 200 deletions
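
To illustrate the new A=F32, B=BF16 special case: each BF16 vector of B is widened into two F32 vectors with one lane-crossing promotion per half, then combined with two consecutive F32 vectors of A via ordinary FMA. The following is a minimal standalone sketch, not the shipped kernel; the name `DotF32Bf16` is made up here, and it assumes `kc` is a multiple of the BF16 vector length (the real `LoopKC` also handles remainders via `LoadN`).

#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Illustrative only: dot product of one F32 row of A with one BF16 row of B.
HWY_ATTR float DotF32Bf16(const float* HWY_RESTRICT a,
                          const hwy::bfloat16_t* HWY_RESTRICT b, size_t kc) {
  const hn::ScalableTag<hwy::bfloat16_t> dbf;
  const hn::Repartition<float, decltype(dbf)> df;
  const size_t NBF = hn::Lanes(dbf);  // == 2 * Lanes(df)
  const size_t NF = hn::Lanes(df);
  auto sum0 = hn::Zero(df);
  auto sum1 = hn::Zero(df);  // Two accumulators help hide FMA latency.
  for (size_t i = 0; i < kc; i += NBF) {
    const auto bv = hn::LoadU(dbf, b + i);
    // One lane-crossing promotion per half of the BF16 vector.
    const auto b_lo = hn::PromoteLowerTo(df, bv);
    const auto b_hi = hn::PromoteUpperTo(df, bv);
    sum0 = hn::MulAdd(hn::LoadU(df, a + i), b_lo, sum0);
    sum1 = hn::MulAdd(hn::LoadU(df, a + i + NF), b_hi, sum1);
  }
  return hn::ReduceSum(df, hn::Add(sum0, sum1));
}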

View File

@ -85,15 +85,21 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
// MatMul requires that A's padding be zero-initialized.
hwy::ZeroBytes(
compressed.Row(r) + extents.cols,
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
});
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
return compressed;
}
// Same, but `extents` describes the transposed matrix.
// Same, but `extents` describes the transposed matrix and the computation of
// `f` swaps `r` and `c`.
template <typename MatT>
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
const Allocator& allocator,
@ -112,8 +118,13 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
row[c] = f;
}
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
MakeSpan(compressed.Row(r), compressed.Cols()),
MakeSpan(compressed.Row(r), extents.cols),
/*packed_ofs=*/0);
// MatMul requires that B's padding be zero-initialized.
hwy::ZeroBytes(
compressed.Row(r) + extents.cols,
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
});
// Arbitrary value, different from 1, must match `GenerateMat`.

View File

@ -157,15 +157,16 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
struct DotKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
// to be `float` in order to have `Raw = double`. To avoid loss of accuracy,
// if either is float, we decompress both to float, otherwise `BF16`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double,
hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>>;
using State = double;
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
VR&, VR&, VR&, VR&) const {
@ -175,6 +176,41 @@ struct DotKernelDouble {
sum3 = hn::MulAdd(w3, v3, sum3);
}
// Raw = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, const VR v0, const VR v1, const VR v2,
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS&, VS&, VS&, VS&) const {
const hn::Repartition<double, DRaw> dd;
using VD = hn::Vec<decltype(dd)>;
VD w0d = hn::PromoteLowerTo(dd, w0);
VD w1d = hn::PromoteLowerTo(dd, w1);
VD w2d = hn::PromoteLowerTo(dd, w2);
VD w3d = hn::PromoteLowerTo(dd, w3);
VD v0d = hn::PromoteLowerTo(dd, v0);
VD v1d = hn::PromoteLowerTo(dd, v1);
VD v2d = hn::PromoteLowerTo(dd, v2);
VD v3d = hn::PromoteLowerTo(dd, v3);
sum0 = hn::MulAdd(w0d, v0d, sum0);
sum1 = hn::MulAdd(w1d, v1d, sum1);
sum2 = hn::MulAdd(w2d, v2d, sum2);
sum3 = hn::MulAdd(w3d, v3d, sum3);
w0d = hn::PromoteUpperTo(dd, w0);
w1d = hn::PromoteUpperTo(dd, w1);
w2d = hn::PromoteUpperTo(dd, w2);
w3d = hn::PromoteUpperTo(dd, w3);
v0d = hn::PromoteUpperTo(dd, v0);
v1d = hn::PromoteUpperTo(dd, v1);
v2d = hn::PromoteUpperTo(dd, v2);
v3d = hn::PromoteUpperTo(dd, v3);
sum0 = hn::MulAdd(w0d, v0d, sum0);
sum1 = hn::MulAdd(w1d, v1d, sum1);
sum2 = hn::MulAdd(w2d, v2d, sum2);
sum3 = hn::MulAdd(w3d, v3d, sum3);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
@ -217,11 +253,26 @@ struct DotKernelDouble {
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VR& sum0,
VR&) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
// Raw = float
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VS& sum0,
VS&) const {
const hn::Repartition<double, DRaw> dd;
using VD = hn::Vec<decltype(dd)>;
VD w0d = hn::PromoteLowerTo(dd, w0);
VD v0d = hn::PromoteLowerTo(dd, v0);
sum0 = hn::MulAdd(w0d, v0d, sum0);
w0d = hn::PromoteUpperTo(dd, w0);
v0d = hn::PromoteUpperTo(dd, v0);
sum0 = hn::MulAdd(w0d, v0d, sum0);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>

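As a plain-C++ reference for the numerics the `Raw` selection above is meant to preserve: a mixed bf16,f32 (or f32,bf16) pair is first widened to float, and the products are then accumulated in double. Illustrative scalar sketch only; `ScalarDotBF16F32` is a name invented for this example:

#include <stddef.h>

#include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo

// Scalar reference: widen bf16 to f32, then form and sum the products in f64.
double ScalarDotBF16F32(const hwy::bfloat16_t* w, const float* v, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    const float wf = hwy::ConvertScalarTo<float>(w[i]);  // bf16 -> f32
    sum += static_cast<double>(wf) * static_cast<double>(v[i]);
  }
  return sum;
}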
View File

@ -113,7 +113,7 @@ class MMStoreHorizontalSumsIntoC {
const size_t row_c, const size_t col_c,
const MMArgs& args, RowPtrs<TC> C_rows) const {
HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
const size_t N = hn::Lanes(df);
HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
// log(N) operations for vectors of length N. Because `kNR` == 4, we
// instead use `StoreInterleaved4` for a vector length-agnostic
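
A sketch of that `StoreInterleaved4` trick (assuming `kNR == 4`, vectors of at least four float lanes, and the surrounding `hn` alias; the actual `MMStoreHorizontalSumsIntoC` additionally scales, optionally adds, and writes into C):

// After StoreInterleaved4, buf holds {C0[i], C1[i], C2[i], C3[i]} for each
// lane i, so adding the 4-wide groups vertically yields all four horizontal
// sums at once, independent of the vector length.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void HorizontalSums4(DF df, VF C0, VF C1, VF C2, VF C3,
                                float* HWY_RESTRICT out4) {
  HWY_ALIGN float buf[4 * hn::MaxLanes(df)];
  hn::StoreInterleaved4(C0, C1, C2, C3, df, buf);
  const hn::CappedTag<float, 4> d4;
  hn::Vec<decltype(d4)> sums = hn::Zero(d4);
  for (size_t i = 0; i < 4 * hn::Lanes(df); i += 4) {
    sums = hn::Add(sums, hn::LoadU(d4, buf + i));
  }
  hn::StoreU(sums, d4, out4);  // out4[j] is the horizontal sum of Cj.
}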
@ -230,7 +230,7 @@ class MMAddHorizontalSumsIntoPartial {
const hn::Repartition<double, DF> dd;
HWY_ALIGN double buf[16 * hn::MaxLanes(dd)];
using VD = hn::Vec<decltype(dd)>;
const size_t ND = hn::Lanes(dd);
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
VD C00 = SumOfPromotedPairs(dd, F00);
VD C01 = SumOfPromotedPairs(dd, F01);
VD C02 = SumOfPromotedPairs(dd, F02);
@ -351,8 +351,8 @@ class MMKernel {
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
template <class Tag, typename TC>
static HWY_INLINE void A2C0(const StridedViewBF& A_view,
template <class Tag, typename TA, typename TC>
static HWY_INLINE void A2C0(const StridedView<TA> A_view, const bool A_padded,
const StridedViewBF& B_view, size_t mr,
const IndexRange& range_mc, const size_t row_b,
size_t kc, Tag tag, const MMArgs& args,
@ -365,8 +365,8 @@ class MMKernel {
// M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) {
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
}
return;
}
@ -375,13 +375,13 @@ class MMKernel {
if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) {
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
}
}
if (HWY_UNLIKELY(imc != mc)) {
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
}
return;
}
@ -389,18 +389,20 @@ class MMKernel {
HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) {
LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
args, C_rows);
}
}
const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
imc += 2;
}
if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
C_rows);
imc += 1;
}
HWY_DASSERT(imc == mc);
@ -423,9 +425,10 @@ class MMKernel {
// `MMAddHorizontalSumsIntoPartial`.
template <class DBF, class VBF = hn::Vec<DBF>,
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
static HWY_INLINE void ElementwiseMulAcc(DBF dbf, VBF a, VBF b0, VBF b1,
VBF b2, VBF b3, VF& C0, VF& C1,
VF& C2, VF& C3) {
static HWY_INLINE void ElementwiseMulAccNativeBF(DBF dbf, VBF a, VBF b0,
VBF b1, VBF b2, VBF b3,
VF& C0, VF& C1, VF& C2,
VF& C3) {
// This handles a single row of A, so the horizontal sums of `C0..3` are the
// (partial) dot products for 4 consecutive values in one row of C.
static_assert(kNR == 4);
@ -443,16 +446,17 @@ class MMKernel {
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
}
// Like `ElementwiseMulAcc`, but splits BF16 inputs into odd and even f32
// for use with FMA. Also handles two rows at a time to hide the FMA latency
// (we assume 4 cycles and dual-issue) before writing `C00` again.
// Like `ElementwiseMulAccNativeBF`, but splits BF16 inputs into odd and even
// f32 for use with FMA. Also handles two rows at a time to hide the FMA
// latency (we assume 4 cycles and dual-issue) before writing `C00` again.
template <class DBF, class VBF = hn::Vec<DBF>,
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
static HWY_INLINE void ElementwiseMulAcc2(DBF dbf, VBF a0, VBF a1, VF b0o,
VF b0e, VF b1o, VF b1e, VF b2o,
VF b2e, VF b3o, VF b3e, VF& C00,
VF& C01, VF& C02, VF& C03, VF& C10,
VF& C11, VF& C12, VF& C13) {
static HWY_INLINE void ElementwiseMulAccEmuBF(DBF dbf, VBF a0, VBF a1, VF b0o,
VF b0e, VF b1o, VF b1e, VF b2o,
VF b2e, VF b3o, VF b3e, VF& C00,
VF& C01, VF& C02, VF& C03,
VF& C10, VF& C11, VF& C12,
VF& C13) {
const DF df;
HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
// Avoid `ReorderWidenMulAccumulate` because it requires extra adds for
@ -491,20 +495,36 @@ class MMKernel {
}
}
// Innermost loop over `kc` columns (typically 1024-4096) in steps of one
// vector, for `kRowsAC` rows of `A_view` from range_mc-relative `imc` and
// `B_view` from row 0 (both at column 0). Updates a `kRowsAC x kNR` tile
// with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
// BF16 so we can load directly without `Decompress2`, which is expensive for
// NUQ and requires 2x unrolling, which requires more loads.
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
static HWY_INLINE void LoopKC(const StridedViewBF& A_view,
// For A=F32, B=BF16 without native BF16 dot product: one lane-crossing
// promotion is likely cheaper than AND+SHIFT for promoting odd/even BF.
// Caller already promoted B, so all inputs are F32.
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2,
VF b3, VF& C0, VF& C1, VF& C2,
VF& C3) {
HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
C0 = hn::MulAdd(a, b0, C0);
C1 = hn::MulAdd(a, b1, C1);
C2 = hn::MulAdd(a, b2, C2);
C3 = hn::MulAdd(a, b3, C3);
}
// Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
// multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
// from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
// Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
// `B` is BF16, `A` and `C` can be F32 or BF16.
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
const bool A_padded,
const StridedViewBF& B_view, size_t row_ac,
size_t imc, size_t col_c, size_t kc, Tag tag,
const MMArgs& args, RowPtrs<TC> C_rows) {
const hn::ScalableTag<BF16> dbf;
using VBF = hn::Vec<decltype(dbf)>;
const size_t NBF = hn::Lanes(dbf);
HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag<TA>());
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
HWY_DASSERT(kRowsAC <= kMaxMR);
HWY_DASSERT(col_c % kNR == 0);
@ -512,30 +532,36 @@ class MMKernel {
// `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
static_assert(kNR == 4);
const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
const BF16* HWY_RESTRICT br0 = B_view.Row(0);
const BF16* HWY_RESTRICT br1 = B_view.Row(1);
const BF16* HWY_RESTRICT br2 = B_view.Row(2);
const BF16* HWY_RESTRICT br3 = B_view.Row(3);
// Ensure `A` and `B` were zero-padded by `DecompressAndZeroPad`.
// Ensure `A` and `B` were zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
// Only check if `A` is padded, i.e. not packed.
if (A_padded) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
{
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
}
if constexpr (kRowsAC > 1) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
}
if constexpr (kRowsAC > 2) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
}
if constexpr (kRowsAC > 3) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
}
}
}
// B is unconditionally zero-padded by `DecompressAndZeroPad`.
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
{
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
}
if constexpr (kRowsAC > 1) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
}
if constexpr (kRowsAC > 2) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
}
if constexpr (kRowsAC > 3) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
}
HWY_DASSERT(hwy::ConvertScalarTo<float>(br0[i]) == 0.0f);
HWY_DASSERT(hwy::ConvertScalarTo<float>(br1[i]) == 0.0f);
HWY_DASSERT(hwy::ConvertScalarTo<float>(br2[i]) == 0.0f);
@ -553,60 +579,287 @@ class MMKernel {
C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df),
C33 = hn::Zero(df);
HWY_UNROLL(1)
for (size_t ikc = 0; ikc < kc; ikc += NBF) {
size_t ikc = 0;
// The loop step is always NBF: for non-native BF16 with TA=F32, this
// entails 2x unrolling, which helps a little.
const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
// If A is packed (not padded), we have to check for remainders. Otherwise,
// we only run the main loop because A's padding is zero-initialized by
// `ZeroInit` or weights.cc.
const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
if (kc_end >= kc_step) {
HWY_UNROLL(1)
for (; ikc <= kc_end - kc_step; ikc += kc_step) {
if constexpr (HWY_NATIVE_DOT_BF16) {
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
// Should only get here if `A` is BF16, otherwise `DecompressA` would
// convert to BF16 and `A_view` points to that.
HWY_DASSERT(IsBF16<TA>());
{
const VBF a0 = hn::Load(dbf, ar0 + ikc);
ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VBF a1 = hn::Load(dbf, ar1 + ikc);
ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12,
C13);
}
if constexpr (kRowsAC > 2) {
const VBF a2 = hn::Load(dbf, ar2 + ikc);
ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VBF a3 = hn::Load(dbf, ar3 + ikc);
ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32,
C33);
}
} else { // !HWY_NATIVE_DOT_BF16
if constexpr (IsBF16<TA>()) {
// When both are BF16, it is better to load and promote odd/even lanes,
// because lane-crossing promotion for both might be bottlenecked on
// shuffles.
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
{
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
b0e = hn::PromoteEvenTo(df, b0);
b1e = hn::PromoteEvenTo(df, b1);
b2e = hn::PromoteEvenTo(df, b2);
b3e = hn::PromoteEvenTo(df, b3);
b0o = FastPromoteOddTo(df, b0);
b1o = FastPromoteOddTo(df, b1);
b2o = FastPromoteOddTo(df, b2);
b3o = FastPromoteOddTo(df, b3);
}
// Two rows at a time so we have 8 separate dependency chains,
// sufficient for IPC=2 and 4-cycle latency.
{
const VBF a0 = hn::Load(dbf, ar0 + ikc);
const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
b3o, b3e, C00, C01, C02, C03, C10, C11,
C12, C13);
}
if constexpr (kRowsAC > 2) {
const VBF a2 = hn::Load(dbf, ar2 + ikc);
const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
b3o, b3e, C20, C21, C22, C23, C30, C31,
C32, C33);
}
} else { // IsF32<TA>(): promote BF to 2xF32, F32*F32.
// Full-vector loads are a bit faster on SKX than half + PromoteTo.
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
const VF b00 = hn::PromoteLowerTo(df, b0);
const VF b10 = hn::PromoteLowerTo(df, b1);
const VF b20 = hn::PromoteLowerTo(df, b2);
const VF b30 = hn::PromoteLowerTo(df, b3);
const VF b01 = hn::PromoteUpperTo(df, b0);
const VF b11 = hn::PromoteUpperTo(df, b1);
const VF b21 = hn::PromoteUpperTo(df, b2);
const VF b31 = hn::PromoteUpperTo(df, b3);
{
const VF a00 = hn::Load(df, ar0 + ikc);
ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VF a10 = hn::Load(df, ar1 + ikc);
ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
C13);
}
// C00 is ready again. On SKX, this interleaved unrolling is faster
// than consuming all `b*1` at the end of the loop.
{
const VF a01 = hn::Load(df, ar0 + ikc + NA);
ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VF a11 = hn::Load(df, ar1 + ikc + NA);
ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
C13);
}
if constexpr (kRowsAC > 2) {
const VF a20 = hn::Load(df, ar2 + ikc);
ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VF a30 = hn::Load(df, ar3 + ikc);
ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
C33);
}
if constexpr (kRowsAC > 2) {
const VF a21 = hn::Load(df, ar2 + ikc + NA);
ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VF a31 = hn::Load(df, ar3 + ikc + NA);
ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
C33);
}
}
}
}
}
// We want the number of remaining valid kc columns, but `ikc` may already be beyond `kc`.
const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
HWY_DASSERT(remaining_kc < kc_step);
HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
// Last iteration: B is padded but A is not; guard its loads.
if (HWY_UNLIKELY(remaining_kc != 0)) {
if constexpr (HWY_NATIVE_DOT_BF16) {
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
// Should only get here if `A` is BF16, otherwise `DecompressA` would
// convert to BF16 and `A_view` points to that.
HWY_DASSERT(IsBF16<TA>());
{
const VBF a0 = hn::Load(dbf, ar0 + ikc);
ElementwiseMulAcc(dbf, a0, b0, b1, b2, b3, C00, C01, C02, C03);
const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VBF a1 = hn::Load(dbf, ar1 + ikc);
ElementwiseMulAcc(dbf, a1, b0, b1, b2, b3, C10, C11, C12, C13);
const VBF a1 = hn::LoadN(dbf, ar1 + ikc, remaining_kc);
ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12,
C13);
}
if constexpr (kRowsAC > 2) {
const VBF a2 = hn::Load(dbf, ar2 + ikc);
ElementwiseMulAcc(dbf, a2, b0, b1, b2, b3, C20, C21, C22, C23);
const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VBF a3 = hn::Load(dbf, ar3 + ikc);
ElementwiseMulAcc(dbf, a3, b0, b1, b2, b3, C30, C31, C32, C33);
const VBF a3 = hn::LoadN(dbf, ar3 + ikc, remaining_kc);
ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32,
C33);
}
} else {
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
{
} else { // !HWY_NATIVE_DOT_BF16
if constexpr (IsBF16<TA>()) {
// When both are BF16, it is better to load and promote odd/even lanes,
// because lane-crossing promotion for both might be bottlenecked on shuffles.
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
{
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
b0e = hn::PromoteEvenTo(df, b0);
b1e = hn::PromoteEvenTo(df, b1);
b2e = hn::PromoteEvenTo(df, b2);
b3e = hn::PromoteEvenTo(df, b3);
b0o = FastPromoteOddTo(df, b0);
b1o = FastPromoteOddTo(df, b1);
b2o = FastPromoteOddTo(df, b2);
b3o = FastPromoteOddTo(df, b3);
}
// Two rows at a time so we have 8 separate dependency chains,
// sufficient for IPC=2 and 4-cycle latency.
{
const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
const VBF a1 =
kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
C13);
}
if constexpr (kRowsAC > 2) {
const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
const VBF a3 =
kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
C33);
}
} else { // IsF32<TA>(): promote half-B to F32, F32*F32.
const VBF b0 = hn::Load(dbf, br0 + ikc);
const VBF b1 = hn::Load(dbf, br1 + ikc);
const VBF b2 = hn::Load(dbf, br2 + ikc);
const VBF b3 = hn::Load(dbf, br3 + ikc);
b0e = hn::PromoteEvenTo(df, b0);
b1e = hn::PromoteEvenTo(df, b1);
b2e = hn::PromoteEvenTo(df, b2);
b3e = hn::PromoteEvenTo(df, b3);
b0o = FastPromoteOddTo(df, b0);
b1o = FastPromoteOddTo(df, b1);
b2o = FastPromoteOddTo(df, b2);
b3o = FastPromoteOddTo(df, b3);
}
const VF b00 = hn::PromoteLowerTo(df, b0);
const VF b10 = hn::PromoteLowerTo(df, b1);
const VF b20 = hn::PromoteLowerTo(df, b2);
const VF b30 = hn::PromoteLowerTo(df, b3);
const VF b01 = hn::PromoteUpperTo(df, b0);
const VF b11 = hn::PromoteUpperTo(df, b1);
const VF b21 = hn::PromoteUpperTo(df, b2);
const VF b31 = hn::PromoteUpperTo(df, b3);
{
const VBF a0 = hn::Load(dbf, ar0 + ikc);
const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
ElementwiseMulAcc2(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
b3e, C00, C01, C02, C03, C10, C11, C12, C13);
}
if constexpr (kRowsAC > 2) {
const VBF a2 = hn::Load(dbf, ar2 + ikc);
const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
ElementwiseMulAcc2(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
b3e, C20, C21, C22, C23, C30, C31, C32, C33);
const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA;
{
const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc);
ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc);
ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
C13);
}
// C00 is ready again. On SKX, this interleaved unrolling is faster
// than consuming all `b*1` at the end of the loop.
{
const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2);
ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
C03);
}
if constexpr (kRowsAC > 1) {
const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2);
ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
C13);
}
if constexpr (kRowsAC > 2) {
const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc);
ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc);
ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
C33);
}
if constexpr (kRowsAC > 2) {
const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2);
ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
C23);
}
if constexpr (kRowsAC > 3) {
const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2);
ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
C33);
}
}
}
}
} // remaining_kc != 0
// This is a substantial fraction (about 1/3) of the total time, but is
// called frequently, so do not add a profiler zone.
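For reference, the emulated (`!HWY_NATIVE_DOT_BF16`) BF16 path above boils down to widening the BF16 lanes into separate even and odd F32 vectors and accumulating with FMA; since even and odd lanes together cover every element, the two accumulators sum to the full dot product. Simplified sketch only (the actual `ElementwiseMulAccEmuBF` promotes B once per tile row, uses the `FastPromoteOddTo` helper, and processes two A rows to hide FMA latency):

template <class DBF, class VBF = hn::Vec<DBF>,
          class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
HWY_INLINE void MulAccEmulatedBF16(DBF /*dbf*/, VBF a, VBF b, VF& sum_even,
                                   VF& sum_odd) {
  const DF df;
  const VF ae = hn::PromoteEvenTo(df, a);  // even BF16 lanes widened to F32
  const VF ao = hn::PromoteOddTo(df, a);   // odd lanes
  const VF be = hn::PromoteEvenTo(df, b);
  const VF bo = hn::PromoteOddTo(df, b);
  sum_even = hn::MulAdd(ae, be, sum_even);
  sum_odd = hn::MulAdd(ao, bo, sum_odd);
}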
@ -678,7 +931,7 @@ class MMScaleDemoteAdd {
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd);
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
@ -796,7 +1049,7 @@ class MMScaleDemoteAdd {
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd);
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
@ -858,41 +1111,51 @@ class MMScaleDemoteAdd {
// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
// Its member variables avoid long argument lists in Do*().
class MMPerPackage {
public:
// Decompression is only required for F32 A and native BF16 dot products.
// If A is already BF16, we can use a view. Padding is not required
// because `LoopKC` can handle non-vector multiples. `LoopKC` also contains
// a special case for F32 `A` and non-native BF16 dot products.
template <typename TA>
MMPerPackage(const MatPtrT<TA>& A, const MMArgs& args, const MMConfig& config,
static constexpr bool WantDecompressA() {
return HWY_NATIVE_DOT_BF16 && IsF32<TA>();
}
public:
MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
size_t pkg_idx, const IndexRange& range_np)
: args_(args),
pkg_idx_(pkg_idx),
// May be overwritten with a view of A, if already BF16.
A_(args_.env->storage.A(pkg_idx, A.Extents())),
range_np_(range_np),
mr_(config.MR()),
ranges_mc_(config.RangesOfMC(A.Rows())),
ranges_kc_(config.RangesOfKC(A.Cols())),
ranges_mc_(config.RangesOfMC(A.rows)),
ranges_kc_(config.RangesOfKC(A.cols)),
ranges_nc_(config.RangesOfNC(range_np)),
order_(config.Order()),
inner_tasks_(config.InnerTasks()),
out_(config.Out()),
line_bytes_(args.env->ctx.allocator.LineBytes()) {
A_ = DecompressA(A);
line_bytes_(args.env->ctx.allocator.LineBytes()) {}
// The element size of `A` that will actually be used, for choosing the
// autotuning candidates. Keep in sync with the `operator()` logic below.
template <typename TA>
static constexpr size_t ABytes() {
return WantDecompressA<TA>() ? sizeof(BF16) : sizeof(TA);
}
// B is decompressed several call layers lower, but not all member functions
// depend on TB, so pass it as an argument instead of templating the class.
template <typename TB, typename TC>
HWY_NOINLINE void operator()(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
switch (order_) {
case MMOrder::kNT:
return DoNT(B, C_rows);
case MMOrder::kNT_K:
return DoNT_K(B, C_rows);
case MMOrder::kNT_MT:
return DoNT_MT(B, C_rows);
case MMOrder::kNT_MT_K:
return DoNT_MT_K(B, C_rows);
default:
HWY_UNREACHABLE;
// B and maybe A are decompressed several call layers lower, but not all
// member functions depend on TA/TB, so pass them as an argument instead of
// templating the class.
template <typename TA, typename TB, typename TC>
HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
if constexpr (WantDecompressA<TA>()) {
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
DecompressA(A, A_view);
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
DispatchOrder(A_view, A_padded, B, C_rows);
} else {
const bool A_padded = HasPadding(A);
DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
}
}
@ -909,16 +1172,57 @@ class MMPerPackage {
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
}
// Use instead of `MatPtr::IsPacked` because that returns true for single
// rows, but we want to know whether there is padding.
static bool HasPadding(const MatPtr& mat) {
return mat.Stride() > mat.Cols();
}
// Returns a 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`
// and `B` are const, but StridedView is also used for non-const `partial`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag<T>());
(void)N;
// If `AB` is padded, then `LoopKC` expects the view to be either a vector
// multiple, or all columns and thus also padded.
HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
template <typename TA, typename TB, typename TC>
HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B,
RowPtrs<TC> C_rows) const {
switch (order_) {
case MMOrder::kNT:
return DoNT(A, A_padded, B, C_rows);
case MMOrder::kNT_K:
return DoNT_K(A, A_padded, B, C_rows);
case MMOrder::kNT_MT:
return DoNT_MT(A, A_padded, B, C_rows);
case MMOrder::kNT_MT_K:
return DoNT_MT_K(A, A_padded, B, C_rows);
default:
HWY_UNREACHABLE;
}
}
// Single M and K ranges, parallel N. Fills all of C directly.
template <typename TB, typename TC>
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_M = ranges_mc_.Range(0);
const IndexRange& range_K = ranges_kc_.Range(0);
const size_t K = range_K.Num();
const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K);
const StridedView<TA> A_view = A.View(range_M.begin(), 0, K);
const size_t B_stride =
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
@ -936,8 +1240,8 @@ class MMPerPackage {
row_b += kNR) {
StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_, C_rows);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
MMSetC(), args_, C_rows);
}
});
@ -945,8 +1249,9 @@ class MMPerPackage {
}
// Single M range, parallel N, sequential K. Fills all of partial.
template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
const IndexRange& range_mc = ranges_mc_.Range(0);
@ -958,8 +1263,8 @@ class MMPerPackage {
const IndexRange& range_nc,
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const StridedViewBF& A_view =
A_.View(range_mc.begin(), range_kc.begin(), kc);
const StridedView<TA> A_view =
A.View(range_mc.begin(), range_kc.begin(), kc);
const StridedViewBF B_storage_view(
B_storage, kc,
Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
@ -967,8 +1272,8 @@ class MMPerPackage {
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
out_tag, args_, C_rows);
}
};
@ -1013,8 +1318,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0);
@ -1031,7 +1337,7 @@ class MMPerPackage {
MMZone mm_zone;
mm_zone.MaybeEnter(worker, zone, args_);
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
const StridedView<TA> A_view = A.View(range_mc.begin(), 0, K);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const StridedViewBF B_storage_view(B_storage, K, B_stride);
@ -1039,8 +1345,8 @@ class MMPerPackage {
row_b += kNR) {
StridedViewBF B_view =
DecompressB(B, row_b, range_K, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_, C_rows);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
MMSetC(), args_, C_rows);
}
});
@ -1049,8 +1355,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
template <typename TA, typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
static const auto fill_zone =
args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC");
@ -1067,14 +1374,14 @@ class MMPerPackage {
const IndexRange& range_nc,
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const StridedViewBF& A_view =
A_.View(range_mc.begin(), range_kc.begin(), kc);
const StridedView<TA> A_view =
A.View(range_mc.begin(), range_kc.begin(), kc);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C_rows);
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
out_tag, args_, C_rows);
}
}; // loop_nc
args_.env->parallel.ForRangesMC_NC(
@ -1107,17 +1414,16 @@ class MMPerPackage {
});
}
// Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a
// seekable type (i.e., not NUQ) so we can use pointer arithmetic.
template <typename TA>
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
MMParA par_a) const {
const IndexRange all_M(0, A.Rows());
const IndexRange all_K(0, A.Cols());
HWY_DASSERT(all_K.Num() == A_.Cols());
HWY_DASSERT(all_K.Num() == A_view.Cols());
const hn::ScalableTag<BF16> dbf;
const size_t NBF = hn::Lanes(dbf);
static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA");
@ -1133,8 +1439,9 @@ class MMPerPackage {
// otherwise `DecompressAndZeroPad` overwrites neighbors.
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
for (size_t row_a : range_M) {
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
const PackedSpan<const float> from =
MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
DecompressAndZeroPad(dbf, from, 0, to, cols);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
@ -1175,23 +1482,12 @@ class MMPerPackage {
}
// Autotuning wrapper for `DoDecompressA`.
template <typename TA>
HWY_INLINE StridedViewBF DecompressA(const MatPtrT<TA>& A) const {
HWY_INLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view) const {
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if vector multiple and padded (see `DoDecompressA`).
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
// Const, but cast because StridedView is also used for `partial` which
// is non-const.
return StridedViewBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
}
}
if (HWY_LIKELY(autotune.Best())) {
DoDecompressA(A, *autotune.Best());
return A_;
return DoDecompressA(A, A_view, *autotune.Best());
}
// First call: generate candidates.
@ -1204,7 +1500,7 @@ class MMPerPackage {
const MMParA& par_a = autotune.NextConfig();
const uint64_t t0 = hwy::timer::Start();
DoDecompressA(A, par_a);
DoDecompressA(A, A_view, par_a);
const uint64_t t1 =
args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
@ -1213,7 +1509,6 @@ class MMPerPackage {
static_cast<double>(min_elapsed) /
hwy::platform::InvariantTicksPerSecond() * 1E6);
}
return A_;
}
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
@ -1223,12 +1518,17 @@ class MMPerPackage {
HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
const IndexRange& range_kc,
const StridedViewBF& B_view) const {
const hn::ScalableTag<BF16> dbf;
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
// View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
// and weights.cc zero-initialize the padding.
if constexpr (hwy::IsSame<TB, BF16>()) {
return StridedViewBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
range_kc.Num(), B.Stride());
if (B.Cols() % NBF == 0 || HasPadding(B)) {
return View(B, row_b, range_kc.begin(), range_kc.Num());
}
}
const hn::ScalableTag<BF16> dbf;
const PackedSpan<const TB> B_span = B.PaddedSpan();
const size_t kc = range_kc.Num();
@ -1240,7 +1540,7 @@ class MMPerPackage {
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
// Verify that we zero-padded.
if constexpr (HWY_IS_DEBUG_BUILD) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, hn::Lanes(dbf)); ++i) {
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
}
}
@ -1250,7 +1550,6 @@ class MMPerPackage {
const MMArgs args_; // copy for locality
const size_t pkg_idx_;
StridedViewBF A_; // view into A or pkg_A_, both of which are padded.
const IndexRange range_np_;
// From MMConfig:
@ -1293,13 +1592,14 @@ struct MMImpl {
MMZone mm_zone;
mm_zone.MaybeEnter(pkg_idx, zone, args);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
C_rows);
});
} else {
const size_t pkg_idx = 0;
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
}
}
};
@ -1310,7 +1610,7 @@ struct MMImpl {
// `K = B.Cols()`, which must match `A.Cols()`, is the number
// of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
// are no other restrictions on shape, though performance is better when `M % 4
// == 0` or `M <= 4`, and when A is padded (`!A.IsPacked()`).
// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
//
// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
// hwy::RoundUpTo(Cols(), hn::Lanes(dbf)))` must be zero-initialized to match
@ -1376,8 +1676,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_ASSERT(N <= MMStorage::kMaxN);
HWY_ASSERT(N % kNR == 0);
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
kNR, per_key.ranges_np, env.print_config));
tuner.SetCandidates(
MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
kMaxMR, kNR, per_key.ranges_np, env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();

View File

@ -64,19 +64,21 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
class GenerateCandidates {
public:
GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
size_t sizeof_TA, size_t sizeof_TC, size_t max_mr,
size_t nr, const IndexRangePartition& ranges_np,
bool print_config)
: allocator_(allocator),
M_(M),
K_(K),
N_(N),
sizeof_TA_(sizeof_TA),
sizeof_TC_(sizeof_TC),
max_mr_(max_mr),
nr_(nr),
// These influence kc/nc, but are also stored in `MMConfig` for
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
// is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size.
// up to the line size. Use BF16, not sizeof_TA, because B is BF16.
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
nc_multiple_(allocator.StepBytes() / sizeof_TC),
ranges_np_(ranges_np),
@ -176,8 +178,9 @@ class GenerateCandidates {
// subtract the output and buf, and allow using more than the actual L1
// size. This results in an overestimate, and the loop below will propose
// the next few smaller values for the autotuner to evaluate.
const size_t bytes_ab = allocator_.L1Bytes() * 3;
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
const size_t bytes_ab =
allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream));
const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16);
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
kc_max =
RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
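As a rough worked example of this heuristic (illustrative numbers only: 32 KiB of L1, F32 A so `sizeof_TA_` = 4, `sizeof(SfpStream)` assumed to be 1 byte, `rows_a` = 4, `nr_` = 4):

  bytes_ab  = 32768 * (4 + 1)          = 163840
  col_bytes = 4 * 4 + 4 * sizeof(BF16) = 16 + 8 = 24
  kc_max    = DivCeil(163840, 24)      = 6827, then capped at MMStorage::kMaxKC
              and rounded down to kc_multiple_.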
@ -224,7 +227,7 @@ class GenerateCandidates {
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
// partial.
const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes();
size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
HWY_DASSERT(mc_max != 0);
@ -359,6 +362,7 @@ class GenerateCandidates {
const size_t M_;
const size_t K_;
const size_t N_;
const size_t sizeof_TA_;
const size_t sizeof_TC_;
const size_t max_mr_;
@ -376,12 +380,12 @@ class GenerateCandidates {
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
size_t K, size_t N, size_t sizeof_TA,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
ranges_np, print_config)();
return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr,
nr, ranges_np, print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote

View File

@ -281,7 +281,9 @@ class MMStorage {
BindC(partial_storage_, parallel);
}
// Returns per-package matrix view.
// Returns per-package matrix view. Converting A=F32 to BF16 up-front is
// faster than on-the-fly when native BF16 is available: it only happens once,
// not per B tile row, and the cache footprint is smaller.
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
@ -475,8 +477,8 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
size_t K, size_t N, size_t sizeof_TA,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config);

View File

@ -118,17 +118,26 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot() uses double-precision summation.
double tolerance = 12 * norm * eps_f32;
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
// tolerance there.
if (IsF32<TA>() && IsF32<TB>()) {
tolerance += 4 * max_abs * eps_bf16;
// If B is F32, Dot() promotes to F32 or even F64, but MatMul demotes the F32
// to BF16, so add extra tolerance.
if (IsF32<TB>()) {
tolerance += 2 * max_abs * eps_bf16;
}
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
double max_rel = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
double worst_expected = 0.0;
size_t num_outside = 0;
for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
@ -143,15 +152,24 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E\n",
r, c, expected_value, actual_value, norm, max_abs,
tolerance, rel, max_rel);
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
}
}
}
if (max_rel > rel_tolerance) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
}
}
// B is already transposed.
@ -188,9 +206,9 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f;
C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
A.Cols()));
const float dot =
Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols());
C_row[c] = hwy::ConvertScalarTo<TC>(add + scale * dot);
}
}
});
@ -279,6 +297,9 @@ void TestTiny() {
for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += max_packages * 4) {
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
}
}
}
@ -334,6 +355,10 @@ void TestAllMatMul() {
TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
// Non-vector-multiple K.
TestMatMul<F32, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
// minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);