mirror of https://github.com/google/gemma.cpp.git
f32 LoopKC: 1.37x (M=512), 1.19x (M=128) single-K F32,BF16 matmul speedup on SKX

Add a special case for A=F32, B=BF16, used when there is no native BF16 dot product.

dot-inl: ensure bf16,f32 and f32,bf16 both get promoted to float before f64 summation.
matmul.cc: update autotuning to reflect the actual A size.
matmul_test: add all combinations of bf16/f32, report all results (not just the first difference), check non-vector-aligned K.

PiperOrigin-RevId: 800487817
This commit is contained in:
parent 98ddc166db
commit 31c09cca4c
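For readers unfamiliar with the A=F32, B=BF16 special case described in the commit message: when the target lacks a native BF16 dot product, the kernel keeps A in F32 and promotes B's BF16 lanes to F32 before an ordinary F32 FMA. A minimal standalone scalar sketch of the same arithmetic (hypothetical helper names, not part of this patch):

#include <cstddef>
#include <cstdint>
#include <cstring>

// BF16 is the upper 16 bits of an IEEE-754 binary32, so promotion is a shift.
static float BF16ToF32(uint16_t bits) {
  const uint32_t u = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// F32 x BF16 dot product without a native BF16 instruction: promote B,
// accumulate in F32. The SIMD kernel in this commit does the same per vector,
// two half-vectors (lower/upper promotion) at a time.
static float DotF32BF16(const float* a, const uint16_t* b_bf16, size_t n) {
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += a[i] * BF16ToF32(b_bf16[i]);
  return sum;
}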
@ -85,15 +85,21 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
|
|||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
|
||||
MakeSpan(compressed.Row(r), compressed.Cols()),
|
||||
MakeSpan(compressed.Row(r), extents.cols),
|
||||
/*packed_ofs=*/0);
|
||||
|
||||
// MatMul requires that A's padding be zero-initialized.
|
||||
hwy::ZeroBytes(
|
||||
compressed.Row(r) + extents.cols,
|
||||
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
|
||||
});
|
||||
|
||||
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
|
||||
return compressed;
|
||||
}
|
||||
|
||||
// Same, but `extents` describes the transposed matrix.
|
||||
// Same, but `extents` describes the transposed matrix and the computation of
|
||||
// `f` swaps `r` and `c`.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
const Allocator& allocator,
|
||||
|
|
@ -112,8 +118,13 @@ MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
|||
row[c] = f;
|
||||
}
|
||||
Compress(raw.Row(r), raw.Cols(), ws.tls[thread],
|
||||
MakeSpan(compressed.Row(r), compressed.Cols()),
|
||||
MakeSpan(compressed.Row(r), extents.cols),
|
||||
/*packed_ofs=*/0);
|
||||
|
||||
// MatMul requires that B's padding be zero-initialized.
|
||||
hwy::ZeroBytes(
|
||||
compressed.Row(r) + extents.cols,
|
||||
(compressed.Stride() - extents.cols) * compressed.ElementBytes());
|
||||
});
|
||||
|
||||
// Arbitrary value, different from 1, must match `GenerateMat`.
|
||||
|
|
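The reason both generators now zero the bytes between `extents.cols` and `Stride()`: the matmul kernel may load whole vectors that run past the logical column count, and those extra lanes must contribute nothing to the dot products. A rough standalone equivalent (hypothetical helper, float-only for brevity):

#include <cstddef>
#include <cstring>

// Zero the per-row padding in [cols, stride) so that vector loads reading
// past `cols` only pick up zeros.
static void ZeroRowPadding(float* row, size_t cols, size_t stride) {
  std::memset(row + cols, 0, (stride - cols) * sizeof(float));
}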
|
|||
|
|
@ -157,15 +157,16 @@ HWY_MAYBE_UNUSED double ConditionNumber(const VT* HWY_RESTRICT v, size_t num) {
|
|||
// promoted or even DEMOTED to bf16. Runs at about half the speed of f32 FMA.
|
||||
struct DotKernelDouble {
|
||||
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
|
||||
// to be `float` in order to have `Raw = double`. Note that if either type is
|
||||
// smaller than `float`, we may demote the other type from `float` to `BF16`.
|
||||
// to be `float` in order to have `Raw = double`. To avoid loss of accuracy,
|
||||
// if either is float, we decompress both to float, otherwise `BF16`.
|
||||
template <typename VT, typename WT>
|
||||
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
|
||||
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double,
|
||||
hwy::If<IsF32<VT>() || IsF32<WT>(), float, BF16>>;
|
||||
using State = double;
|
||||
|
||||
// Raw = double
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||
HWY_INLINE void Update4(DRaw dd, const VR w0, const VR w1, const VR w2,
|
||||
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
|
||||
const VR w3, const VR v0, const VR v1, const VR v2,
|
||||
const VR v3, VR& sum0, VR& sum1, VR& sum2, VR& sum3,
|
||||
VR&, VR&, VR&, VR&) const {
|
||||
|
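The new `Raw` selection above can be read as a three-way case analysis; a small constexpr restatement (hypothetical, not part of the patch) for reference:

enum class RawKind { kDouble, kFloat, kBF16 };

// Both inputs F32: decompress to double. Exactly one F32: decompress both to
// float (avoids demoting the F32 side to BF16). Neither F32: stay in BF16.
constexpr RawKind SelectRaw(bool v_is_f32, bool w_is_f32) {
  return (v_is_f32 && w_is_f32)   ? RawKind::kDouble
         : (v_is_f32 || w_is_f32) ? RawKind::kFloat
                                  : RawKind::kBF16;
}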
|
@ -175,6 +176,41 @@ struct DotKernelDouble {
|
|||
sum3 = hn::MulAdd(w3, v3, sum3);
|
||||
}
|
||||
|
||||
// Raw = float
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
|
||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
|
||||
const VR w3, const VR v0, const VR v1, const VR v2,
|
||||
const VR v3, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||
VS&, VS&, VS&, VS&) const {
|
||||
const hn::Repartition<double, DRaw> dd;
|
||||
using VD = hn::Vec<decltype(dd)>;
|
||||
VD w0d = hn::PromoteLowerTo(dd, w0);
|
||||
VD w1d = hn::PromoteLowerTo(dd, w1);
|
||||
VD w2d = hn::PromoteLowerTo(dd, w2);
|
||||
VD w3d = hn::PromoteLowerTo(dd, w3);
|
||||
VD v0d = hn::PromoteLowerTo(dd, v0);
|
||||
VD v1d = hn::PromoteLowerTo(dd, v1);
|
||||
VD v2d = hn::PromoteLowerTo(dd, v2);
|
||||
VD v3d = hn::PromoteLowerTo(dd, v3);
|
||||
sum0 = hn::MulAdd(w0d, v0d, sum0);
|
||||
sum1 = hn::MulAdd(w1d, v1d, sum1);
|
||||
sum2 = hn::MulAdd(w2d, v2d, sum2);
|
||||
sum3 = hn::MulAdd(w3d, v3d, sum3);
|
||||
w0d = hn::PromoteUpperTo(dd, w0);
|
||||
w1d = hn::PromoteUpperTo(dd, w1);
|
||||
w2d = hn::PromoteUpperTo(dd, w2);
|
||||
w3d = hn::PromoteUpperTo(dd, w3);
|
||||
v0d = hn::PromoteUpperTo(dd, v0);
|
||||
v1d = hn::PromoteUpperTo(dd, v1);
|
||||
v2d = hn::PromoteUpperTo(dd, v2);
|
||||
v3d = hn::PromoteUpperTo(dd, v3);
|
||||
sum0 = hn::MulAdd(w0d, v0d, sum0);
|
||||
sum1 = hn::MulAdd(w1d, v1d, sum1);
|
||||
sum2 = hn::MulAdd(w2d, v2d, sum2);
|
||||
sum3 = hn::MulAdd(w3d, v3d, sum3);
|
||||
}
|
||||
|
||||
// Raw = BF16
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||
|
|
@ -217,11 +253,26 @@ struct DotKernelDouble {
|
|||
|
||||
// Raw = double
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||
HWY_INLINE void Update1(DRaw dd, const VR w0, const VR v0, VR& sum0,
|
||||
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VR& sum0,
|
||||
VR&) const {
|
||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||
}
|
||||
|
||||
// Raw = float
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F32_D(DRaw),
|
||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||
HWY_INLINE void Update1(DRaw dr, const VR w0, const VR v0, VS& sum0,
|
||||
VS&) const {
|
||||
const hn::Repartition<double, DRaw> dd;
|
||||
using VD = hn::Vec<decltype(dd)>;
|
||||
VD w0d = hn::PromoteLowerTo(dd, w0);
|
||||
VD v0d = hn::PromoteLowerTo(dd, v0);
|
||||
sum0 = hn::MulAdd(w0d, v0d, sum0);
|
||||
w0d = hn::PromoteUpperTo(dd, w0);
|
||||
v0d = hn::PromoteUpperTo(dd, v0);
|
||||
sum0 = hn::MulAdd(w0d, v0d, sum0);
|
||||
}
|
||||
|
||||
// Raw = BF16
|
||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||
|
|
|
|||
633 ops/matmul-inl.h
|
|
@ -113,7 +113,7 @@ class MMStoreHorizontalSumsIntoC {
|
|||
const size_t row_c, const size_t col_c,
|
||||
const MMArgs& args, RowPtrs<TC> C_rows) const {
|
||||
HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df);
|
||||
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
|
||||
// log(N) operations for vectors of length N. Because `kNR` == 4, we
|
||||
// instead use `StoreInterleaved4` for a vector length-agnostic
|
||||
|
|
@ -230,7 +230,7 @@ class MMAddHorizontalSumsIntoPartial {
|
|||
const hn::Repartition<double, DF> dd;
|
||||
HWY_ALIGN double buf[16 * hn::MaxLanes(dd)];
|
||||
using VD = hn::Vec<decltype(dd)>;
|
||||
const size_t ND = hn::Lanes(dd);
|
||||
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
|
||||
VD C00 = SumOfPromotedPairs(dd, F00);
|
||||
VD C01 = SumOfPromotedPairs(dd, F01);
|
||||
VD C02 = SumOfPromotedPairs(dd, F02);
|
||||
|
|
@ -351,8 +351,8 @@ class MMKernel {
|
|||
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
|
||||
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
|
||||
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
|
||||
template <class Tag, typename TC>
|
||||
static HWY_INLINE void A2C0(const StridedViewBF& A_view,
|
||||
template <class Tag, typename TA, typename TC>
|
||||
static HWY_INLINE void A2C0(const StridedView<TA> A_view, const bool A_padded,
|
||||
const StridedViewBF& B_view, size_t mr,
|
||||
const IndexRange& range_mc, const size_t row_b,
|
||||
size_t kc, Tag tag, const MMArgs& args,
|
||||
|
|
@ -365,8 +365,8 @@ class MMKernel {
|
|||
// M == 1, or x86 with 8 SIMD registers:
|
||||
if (HWY_UNLIKELY(mr == 1)) {
|
||||
for (; imc < mc; ++imc) {
|
||||
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
|
||||
args, C_rows);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
@ -375,13 +375,13 @@ class MMKernel {
|
|||
if (HWY_UNLIKELY(mr == 2)) {
|
||||
if (HWY_LIKELY(mc >= 2)) {
|
||||
for (; imc <= mc - 2; imc += 2) {
|
||||
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
|
||||
args, C_rows);
|
||||
}
|
||||
}
|
||||
if (HWY_UNLIKELY(imc != mc)) {
|
||||
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
|
||||
args, C_rows);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
|
@ -389,18 +389,20 @@ class MMKernel {
|
|||
HWY_DASSERT(mr == 4);
|
||||
if (HWY_LIKELY(mc >= 4)) {
|
||||
for (; imc <= mc - 4; imc += 4) {
|
||||
LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
LoopKC<4>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag,
|
||||
args, C_rows);
|
||||
}
|
||||
}
|
||||
const size_t remainder_mc = mc - imc;
|
||||
HWY_DASSERT(remainder_mc < 4);
|
||||
if (HWY_UNLIKELY(remainder_mc & 2)) {
|
||||
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
|
||||
LoopKC<2>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
imc += 2;
|
||||
}
|
||||
if (HWY_UNLIKELY(remainder_mc & 1)) {
|
||||
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
|
||||
LoopKC<1>(A_view, A_padded, B_view, row0 + imc, imc, row_b, kc, tag, args,
|
||||
C_rows);
|
||||
imc += 1;
|
||||
}
|
||||
HWY_DASSERT(imc == mc);
|
||||
|
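The structure of the updated `A2C0` row loop, stripped of the matmul-specific arguments: blocks of `mr` rows, then a remainder of 2 and/or 1, so every `mc` is covered exactly once. A simplified standalone sketch (placeholder names; `visit(imc, rows)` stands in for `LoopKC<rows>(...)`):

#include <cassert>
#include <cstddef>
#include <functional>

// Visit `mc` rows in blocks of 4, then an optional block of 2, then an
// optional single row.
static void ForEachRowBlock(size_t mc,
                            const std::function<void(size_t, size_t)>& visit) {
  size_t imc = 0;
  if (mc >= 4) {
    for (; imc + 4 <= mc; imc += 4) visit(imc, 4);
  }
  const size_t remainder = mc - imc;
  assert(remainder < 4);
  if (remainder & 2) { visit(imc, 2); imc += 2; }
  if (remainder & 1) { visit(imc, 1); imc += 1; }
  assert(imc == mc);
}

The `mr == 1` and `mr == 2` branches follow the same pattern with smaller block sizes.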
|
@ -423,9 +425,10 @@ class MMKernel {
|
|||
// `MMAddHorizontalSumsIntoPartial`.
|
||||
template <class DBF, class VBF = hn::Vec<DBF>,
|
||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void ElementwiseMulAcc(DBF dbf, VBF a, VBF b0, VBF b1,
|
||||
VBF b2, VBF b3, VF& C0, VF& C1,
|
||||
VF& C2, VF& C3) {
|
||||
static HWY_INLINE void ElementwiseMulAccNativeBF(DBF dbf, VBF a, VBF b0,
|
||||
VBF b1, VBF b2, VBF b3,
|
||||
VF& C0, VF& C1, VF& C2,
|
||||
VF& C3) {
|
||||
// This handles a single row of A, so the horizontal sums of `C0..3` are the
|
||||
// (partial) dot products for 4 consecutive values in one row of C.
|
||||
static_assert(kNR == 4);
|
||||
|
|
@ -443,16 +446,17 @@ class MMKernel {
|
|||
HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
|
||||
}
|
||||
|
||||
// Like `ElementwiseMulAcc`, but splits BF16 inputs into odd and even f32
|
||||
// for use with FMA. Also handles two rows at a time to hide the FMA latency
|
||||
// (we assume 4 cycles and dual-issue) before writing `C00` again.
|
||||
// Like `ElementwiseMulAccNativeBF`, but splits BF16 inputs into odd and even
|
||||
// f32 for use with FMA. Also handles two rows at a time to hide the FMA
|
||||
// latency (we assume 4 cycles and dual-issue) before writing `C00` again.
|
||||
template <class DBF, class VBF = hn::Vec<DBF>,
|
||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void ElementwiseMulAcc2(DBF dbf, VBF a0, VBF a1, VF b0o,
|
||||
VF b0e, VF b1o, VF b1e, VF b2o,
|
||||
VF b2e, VF b3o, VF b3e, VF& C00,
|
||||
VF& C01, VF& C02, VF& C03, VF& C10,
|
||||
VF& C11, VF& C12, VF& C13) {
|
||||
static HWY_INLINE void ElementwiseMulAccEmuBF(DBF dbf, VBF a0, VBF a1, VF b0o,
|
||||
VF b0e, VF b1o, VF b1e, VF b2o,
|
||||
VF b2e, VF b3o, VF b3e, VF& C00,
|
||||
VF& C01, VF& C02, VF& C03,
|
||||
VF& C10, VF& C11, VF& C12,
|
||||
VF& C13) {
|
||||
const DF df;
|
||||
HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
|
||||
// Avoid `ReorderWidenMulAccumulate` because it requires extra adds for
|
||||
|
|
@ -491,20 +495,36 @@ class MMKernel {
|
|||
}
|
||||
}
|
||||
|
||||
// Innermost loop over `kc` columns (typically 1024-4096) in steps of one
|
||||
// vector, for `kRowsAC` rows of `A_view` from range_mc-relative `imc` and
|
||||
// `B_view` from row 0 (both at column 0). Updates a `kRowsAC x kNR` tile
|
||||
// with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
|
||||
// BF16 so we can load directly without `Decompress2`, which is expensive for
|
||||
// NUQ and requires 2x unrolling, which requires more loads.
|
||||
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
|
||||
static HWY_INLINE void LoopKC(const StridedViewBF& A_view,
|
||||
// For A=F32, B=BF16 without native BF16 dot product: one lane-crossing
|
||||
// promotion is likely cheaper than AND+SHIFT for promoting odd/even BF.
|
||||
// Caller already promoted B, so all inputs are F32.
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2,
|
||||
VF b3, VF& C0, VF& C1, VF& C2,
|
||||
VF& C3) {
|
||||
HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
|
||||
C0 = hn::MulAdd(a, b0, C0);
|
||||
C1 = hn::MulAdd(a, b1, C1);
|
||||
C2 = hn::MulAdd(a, b2, C2);
|
||||
C3 = hn::MulAdd(a, b3, C3);
|
||||
}
|
||||
|
||||
// Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
|
||||
// multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
|
||||
// from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
|
||||
// Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
|
||||
// `B` is BF16, `A` and `C` can be F32 or BF16.
|
||||
template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
|
||||
static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
|
||||
const bool A_padded,
|
||||
const StridedViewBF& B_view, size_t row_ac,
|
||||
size_t imc, size_t col_c, size_t kc, Tag tag,
|
||||
const MMArgs& args, RowPtrs<TC> C_rows) {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag<TA>());
|
||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
HWY_DASSERT(kRowsAC <= kMaxMR);
|
||||
HWY_DASSERT(col_c % kNR == 0);
|
||||
|
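Both `ElementwiseMulAccEmuBF` and the "AND+SHIFT" remark above rely on how two BF16 values sit inside one 32-bit lane: the odd element occupies the upper 16 bits, so it becomes a valid F32 by masking the low bits, while the even element needs only a 16-bit shift. A scalar illustration (assuming the little-endian lane layout used on x86; helper name is hypothetical):

#include <cstdint>
#include <cstring>

// Split one 32-bit word holding two packed BF16 values into their F32
// equivalents. Masking/shifting suffices because BF16 is the upper half of
// an F32.
static void PromoteBF16Pair(uint32_t packed, float* even, float* odd) {
  const uint32_t even_bits = packed << 16;         // low BF16 -> upper bits
  const uint32_t odd_bits = packed & 0xFFFF0000u;  // high BF16, low bits zeroed
  std::memcpy(even, &even_bits, sizeof(float));
  std::memcpy(odd, &odd_bits, sizeof(float));
}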
|
@ -512,30 +532,36 @@ class MMKernel {
|
|||
|
||||
// `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
|
||||
static_assert(kNR == 4);
|
||||
const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
|
||||
const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
|
||||
const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
|
||||
const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
|
||||
const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
|
||||
const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
|
||||
const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
|
||||
const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
|
||||
const BF16* HWY_RESTRICT br0 = B_view.Row(0);
|
||||
const BF16* HWY_RESTRICT br1 = B_view.Row(1);
|
||||
const BF16* HWY_RESTRICT br2 = B_view.Row(2);
|
||||
const BF16* HWY_RESTRICT br3 = B_view.Row(3);
|
||||
|
||||
// Ensure `A` and `B` were zero-padded by `DecompressAndZeroPad`.
|
||||
// Ensure `A` and `B` were zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
// Only check if `A` is padded, i.e. not packed.
|
||||
if (A_padded) {
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NA); ++i) {
|
||||
{
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// B is unconditionally zero-padded by `DecompressAndZeroPad`.
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
||||
{
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar0[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar1[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar2[i]) == 0.0f);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(ar3[i]) == 0.0f);
|
||||
}
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(br0[i]) == 0.0f);
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(br1[i]) == 0.0f);
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(br2[i]) == 0.0f);
|
||||
|
|
@ -553,60 +579,287 @@ class MMKernel {
|
|||
C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df),
|
||||
C33 = hn::Zero(df);
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t ikc = 0; ikc < kc; ikc += NBF) {
|
||||
size_t ikc = 0;
|
||||
// The loop step is always NBF: for non-native BF16 with TA=F32, this
|
||||
// entails 2x unrolling, which helps a little.
|
||||
const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
|
||||
// If A is packed (not padded), we have to check for remainders. Otherwise,
|
||||
// we only run the main loop because A's padding is zero-initialized by
|
||||
// `ZeroInit` or weights.cc.
|
||||
const size_t kc_end = A_padded ? hwy::RoundUpTo(kc, kc_step) : kc;
|
||||
if (kc_end >= kc_step) {
|
||||
HWY_UNROLL(1)
|
||||
for (; ikc <= kc_end - kc_step; ikc += kc_step) {
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
|
||||
// Should only get here if `A` is BF16, otherwise `DecompressA` would
|
||||
// convert to BF16 and `A_view` points to that.
|
||||
HWY_DASSERT(IsBF16<TA>());
|
||||
|
||||
{
|
||||
const VBF a0 = hn::Load(dbf, ar0 + ikc);
|
||||
ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VBF a1 = hn::Load(dbf, ar1 + ikc);
|
||||
ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VBF a2 = hn::Load(dbf, ar2 + ikc);
|
||||
ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VBF a3 = hn::Load(dbf, ar3 + ikc);
|
||||
ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
} else { // !HWY_NATIVE_DOT_BF16
|
||||
if constexpr (IsBF16<TA>()) {
|
||||
// When both are BF16, it is better to load and promote odd/even,
|
||||
// because lane-crossing promotion for both might be bottlenecked on
|
||||
// shuffles.
|
||||
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
|
||||
{
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
b0e = hn::PromoteEvenTo(df, b0);
|
||||
b1e = hn::PromoteEvenTo(df, b1);
|
||||
b2e = hn::PromoteEvenTo(df, b2);
|
||||
b3e = hn::PromoteEvenTo(df, b3);
|
||||
b0o = FastPromoteOddTo(df, b0);
|
||||
b1o = FastPromoteOddTo(df, b1);
|
||||
b2o = FastPromoteOddTo(df, b2);
|
||||
b3o = FastPromoteOddTo(df, b3);
|
||||
}
|
||||
|
||||
// Two rows at a time so we have 8 separate dependency chains,
|
||||
// sufficient for IPC=2 and 4-cycle latency.
|
||||
{
|
||||
const VBF a0 = hn::Load(dbf, ar0 + ikc);
|
||||
const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
|
||||
ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
|
||||
b3o, b3e, C00, C01, C02, C03, C10, C11,
|
||||
C12, C13);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VBF a2 = hn::Load(dbf, ar2 + ikc);
|
||||
const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
|
||||
ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
|
||||
b3o, b3e, C20, C21, C22, C23, C30, C31,
|
||||
C32, C33);
|
||||
}
|
||||
} else { // IsF32<TA>(): promote BF to 2xF32, F32*F32.
|
||||
// Full-vector loads are a bit faster on SKX than half + PromoteTo.
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
const VF b00 = hn::PromoteLowerTo(df, b0);
|
||||
const VF b10 = hn::PromoteLowerTo(df, b1);
|
||||
const VF b20 = hn::PromoteLowerTo(df, b2);
|
||||
const VF b30 = hn::PromoteLowerTo(df, b3);
|
||||
const VF b01 = hn::PromoteUpperTo(df, b0);
|
||||
const VF b11 = hn::PromoteUpperTo(df, b1);
|
||||
const VF b21 = hn::PromoteUpperTo(df, b2);
|
||||
const VF b31 = hn::PromoteUpperTo(df, b3);
|
||||
|
||||
{
|
||||
const VF a00 = hn::Load(df, ar0 + ikc);
|
||||
ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VF a10 = hn::Load(df, ar1 + ikc);
|
||||
ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
|
||||
// C00 is ready again. On SKX, this interleaved unrolling is faster
|
||||
// than consuming all `b*1` at the end of the loop.
|
||||
{
|
||||
const VF a01 = hn::Load(df, ar0 + ikc + NA);
|
||||
ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VF a11 = hn::Load(df, ar1 + ikc + NA);
|
||||
ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VF a20 = hn::Load(df, ar2 + ikc);
|
||||
ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VF a30 = hn::Load(df, ar3 + ikc);
|
||||
ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VF a21 = hn::Load(df, ar2 + ikc + NA);
|
||||
ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VF a31 = hn::Load(df, ar3 + ikc + NA);
|
||||
ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We want the number of actual valid kc, but we may already be beyond `kc`.
|
||||
const size_t remaining_kc = ikc >= kc ? 0 : kc - ikc;
|
||||
HWY_DASSERT(remaining_kc < kc_step);
|
||||
HWY_DASSERT((remaining_kc == 0) == (A_padded || kc % kc_step == 0));
|
||||
// Last iteration: B is padded but A is not; guard its loads.
|
||||
if (HWY_UNLIKELY(remaining_kc != 0)) {
|
||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
|
||||
// Should only get here if `A` is BF16, otherwise `DecompressA` would
|
||||
// convert to BF16 and `A_view` points to that.
|
||||
HWY_DASSERT(IsBF16<TA>());
|
||||
|
||||
{
|
||||
const VBF a0 = hn::Load(dbf, ar0 + ikc);
|
||||
ElementwiseMulAcc(dbf, a0, b0, b1, b2, b3, C00, C01, C02, C03);
|
||||
const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
|
||||
ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VBF a1 = hn::Load(dbf, ar1 + ikc);
|
||||
ElementwiseMulAcc(dbf, a1, b0, b1, b2, b3, C10, C11, C12, C13);
|
||||
const VBF a1 = hn::LoadN(dbf, ar1 + ikc, remaining_kc);
|
||||
ElementwiseMulAccNativeBF(dbf, a1, b0, b1, b2, b3, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VBF a2 = hn::Load(dbf, ar2 + ikc);
|
||||
ElementwiseMulAcc(dbf, a2, b0, b1, b2, b3, C20, C21, C22, C23);
|
||||
const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
|
||||
ElementwiseMulAccNativeBF(dbf, a2, b0, b1, b2, b3, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VBF a3 = hn::Load(dbf, ar3 + ikc);
|
||||
ElementwiseMulAcc(dbf, a3, b0, b1, b2, b3, C30, C31, C32, C33);
|
||||
const VBF a3 = hn::LoadN(dbf, ar3 + ikc, remaining_kc);
|
||||
ElementwiseMulAccNativeBF(dbf, a3, b0, b1, b2, b3, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
} else {
|
||||
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
|
||||
{
|
||||
} else { // !HWY_NATIVE_DOT_BF16
|
||||
if constexpr (IsBF16<TA>()) {
|
||||
// When both are BF16, it is better to load and promote odd/even, because
|
||||
// lane-crossing promotion for both might be bottlenecked on shuffles.
|
||||
VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
|
||||
{
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
b0e = hn::PromoteEvenTo(df, b0);
|
||||
b1e = hn::PromoteEvenTo(df, b1);
|
||||
b2e = hn::PromoteEvenTo(df, b2);
|
||||
b3e = hn::PromoteEvenTo(df, b3);
|
||||
b0o = FastPromoteOddTo(df, b0);
|
||||
b1o = FastPromoteOddTo(df, b1);
|
||||
b2o = FastPromoteOddTo(df, b2);
|
||||
b3o = FastPromoteOddTo(df, b3);
|
||||
}
|
||||
|
||||
// Two rows at a time so we have 8 separate dependency chains,
|
||||
// sufficient for IPC=2 and 4-cycle latency.
|
||||
{
|
||||
const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
|
||||
const VBF a1 =
|
||||
kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
|
||||
ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
|
||||
b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
|
||||
const VBF a3 =
|
||||
kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
|
||||
ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
|
||||
b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
} else { // IsF32<TA>(): promote half-B to F32, F32*F32.
|
||||
const VBF b0 = hn::Load(dbf, br0 + ikc);
|
||||
const VBF b1 = hn::Load(dbf, br1 + ikc);
|
||||
const VBF b2 = hn::Load(dbf, br2 + ikc);
|
||||
const VBF b3 = hn::Load(dbf, br3 + ikc);
|
||||
b0e = hn::PromoteEvenTo(df, b0);
|
||||
b1e = hn::PromoteEvenTo(df, b1);
|
||||
b2e = hn::PromoteEvenTo(df, b2);
|
||||
b3e = hn::PromoteEvenTo(df, b3);
|
||||
b0o = FastPromoteOddTo(df, b0);
|
||||
b1o = FastPromoteOddTo(df, b1);
|
||||
b2o = FastPromoteOddTo(df, b2);
|
||||
b3o = FastPromoteOddTo(df, b3);
|
||||
}
|
||||
const VF b00 = hn::PromoteLowerTo(df, b0);
|
||||
const VF b10 = hn::PromoteLowerTo(df, b1);
|
||||
const VF b20 = hn::PromoteLowerTo(df, b2);
|
||||
const VF b30 = hn::PromoteLowerTo(df, b3);
|
||||
const VF b01 = hn::PromoteUpperTo(df, b0);
|
||||
const VF b11 = hn::PromoteUpperTo(df, b1);
|
||||
const VF b21 = hn::PromoteUpperTo(df, b2);
|
||||
const VF b31 = hn::PromoteUpperTo(df, b3);
|
||||
|
||||
{
|
||||
const VBF a0 = hn::Load(dbf, ar0 + ikc);
|
||||
const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
|
||||
ElementwiseMulAcc2(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
|
||||
b3e, C00, C01, C02, C03, C10, C11, C12, C13);
|
||||
}
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VBF a2 = hn::Load(dbf, ar2 + ikc);
|
||||
const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
|
||||
ElementwiseMulAcc2(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
|
||||
b3e, C20, C21, C22, C23, C30, C31, C32, C33);
|
||||
const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA;
|
||||
|
||||
{
|
||||
const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc);
|
||||
ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc);
|
||||
ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
|
||||
// C00 is ready again. On SKX, this interleaved unrolling is faster
|
||||
// than consuming all `b*1` at the end of the loop.
|
||||
{
|
||||
const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2);
|
||||
ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
|
||||
C03);
|
||||
}
|
||||
if constexpr (kRowsAC > 1) {
|
||||
const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2);
|
||||
ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
|
||||
C13);
|
||||
}
|
||||
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc);
|
||||
ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc);
|
||||
ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
|
||||
if constexpr (kRowsAC > 2) {
|
||||
const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2);
|
||||
ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
|
||||
C23);
|
||||
}
|
||||
if constexpr (kRowsAC > 3) {
|
||||
const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2);
|
||||
ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
|
||||
C33);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // remaining_kc != 0
|
||||
|
||||
// This is a substantial fraction (about 1/3) of the total time, but is
|
||||
// called frequently, so do not add a profiler zone.
|
||||
|
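The control flow that the padded/packed distinction in `LoopKC` boils down to, with the SIMD details removed: when A is padded (zero-initialized past `kc`), the main loop may run one full step beyond `kc`; when A is packed, the final partial step is handled separately with a bounded load (`LoadN`). A simplified scalar sketch (placeholder names; assumes both inputs are zero-padded up to the rounded length in the padded case):

#include <cstddef>

static float DotWithTail(const float* a, const float* b, size_t kc,
                         size_t step, bool a_padded) {
  float sum = 0.0f;
  // Padded: round up, since the zeroed lanes past kc contribute nothing.
  const size_t kc_end = a_padded ? ((kc + step - 1) / step) * step : kc;
  size_t i = 0;
  for (; i + step <= kc_end; i += step) {
    for (size_t j = 0; j < step; ++j) sum += a[i + j] * b[i + j];
  }
  // Packed: guard the remaining < step elements (the kernel uses LoadN here).
  for (; i < kc; ++i) sum += a[i] * b[i];
  return sum;
}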
|
@ -678,7 +931,7 @@ class MMScaleDemoteAdd {
|
|||
const hn::Rebind<TC, decltype(dd)> dc;
|
||||
using VD = hn::Vec<decltype(dd)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t ND = hn::Lanes(dd);
|
||||
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
|
||||
const VD vscale = hn::Set(dd, args.scale);
|
||||
|
||||
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
|
||||
|
|
@ -796,7 +1049,7 @@ class MMScaleDemoteAdd {
|
|||
const hn::Rebind<TC, decltype(dd)> dc;
|
||||
using VD = hn::Vec<decltype(dd)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t ND = hn::Lanes(dd);
|
||||
HWY_LANES_CONSTEXPR const size_t ND = hn::Lanes(dd);
|
||||
const VD vscale = hn::Set(dd, args.scale);
|
||||
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
|
||||
TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
|
||||
|
|
@ -858,41 +1111,51 @@ class MMScaleDemoteAdd {
|
|||
// outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
|
||||
// Its member variables avoid long argument lists in Do*().
|
||||
class MMPerPackage {
|
||||
public:
|
||||
// Decompression is only required for F32 A and native BF16 dot products.
|
||||
// If A is already BF16, we can use a view. Padding is not required
|
||||
// because `LoopKC` can handle non-vector multiples. `LoopKC` also contains
|
||||
// a special case for F32 `A` and non-native BF16 dot products.
|
||||
template <typename TA>
|
||||
MMPerPackage(const MatPtrT<TA>& A, const MMArgs& args, const MMConfig& config,
|
||||
static constexpr bool WantDecompressA() {
|
||||
return HWY_NATIVE_DOT_BF16 && IsF32<TA>();
|
||||
}
|
||||
|
||||
public:
|
||||
MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
|
||||
size_t pkg_idx, const IndexRange& range_np)
|
||||
: args_(args),
|
||||
pkg_idx_(pkg_idx),
|
||||
// May be overwritten with a view of A, if already BF16.
|
||||
A_(args_.env->storage.A(pkg_idx, A.Extents())),
|
||||
range_np_(range_np),
|
||||
mr_(config.MR()),
|
||||
ranges_mc_(config.RangesOfMC(A.Rows())),
|
||||
ranges_kc_(config.RangesOfKC(A.Cols())),
|
||||
ranges_mc_(config.RangesOfMC(A.rows)),
|
||||
ranges_kc_(config.RangesOfKC(A.cols)),
|
||||
ranges_nc_(config.RangesOfNC(range_np)),
|
||||
order_(config.Order()),
|
||||
inner_tasks_(config.InnerTasks()),
|
||||
out_(config.Out()),
|
||||
line_bytes_(args.env->ctx.allocator.LineBytes()) {
|
||||
A_ = DecompressA(A);
|
||||
line_bytes_(args.env->ctx.allocator.LineBytes()) {}
|
||||
|
||||
// The size of `A` that will actually be used, for purposes of choosing the
|
||||
// autotuning candidates. Keep in sync with the `operator()` logic below.
|
||||
template <typename TA>
|
||||
static constexpr size_t ABytes() {
|
||||
return WantDecompressA<TA>() ? sizeof(BF16) : sizeof(TA);
|
||||
}
|
||||
|
||||
// B is decompressed several call layers lower, but not all member functions
|
||||
// depend on TB, so pass it as an argument instead of templating the class.
|
||||
template <typename TB, typename TC>
|
||||
HWY_NOINLINE void operator()(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
switch (order_) {
|
||||
case MMOrder::kNT:
|
||||
return DoNT(B, C_rows);
|
||||
case MMOrder::kNT_K:
|
||||
return DoNT_K(B, C_rows);
|
||||
case MMOrder::kNT_MT:
|
||||
return DoNT_MT(B, C_rows);
|
||||
case MMOrder::kNT_MT_K:
|
||||
return DoNT_MT_K(B, C_rows);
|
||||
default:
|
||||
HWY_UNREACHABLE;
|
||||
// B and maybe A are decompressed several call layers lower, but not all
|
||||
// member functions depend on TA/TB, so pass them as an argument instead of
|
||||
// templating the class.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_NOINLINE void operator()(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
if constexpr (WantDecompressA<TA>()) {
|
||||
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
|
||||
DecompressA(A, A_view);
|
||||
constexpr bool A_padded = true; // MMStorage `pkg_A_` is padded.
|
||||
DispatchOrder(A_view, A_padded, B, C_rows);
|
||||
} else {
|
||||
const bool A_padded = HasPadding(A);
|
||||
DispatchOrder(View(A, 0, 0, A.Cols()), A_padded, B, C_rows);
|
||||
}
|
||||
}
|
||||
|
||||
|
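Summarizing the new dispatch: A is converted to BF16 up front only when that conversion will pay off (native BF16 dot product); otherwise `LoopKC` consumes A in its original type, and the autotuner is told the element size the kernel will actually stream. A compact restatement (hypothetical free functions mirroring `WantDecompressA`/`ABytes`):

#include <cstddef>

constexpr bool WantDecompressA(bool a_is_f32, bool native_bf16_dot) {
  return native_bf16_dot && a_is_f32;
}

// Element size of A as seen by the kernel (and thus by the autotuner):
// BF16 (2 bytes) after up-front conversion, otherwise A's own size.
constexpr size_t ABytesForAutotune(bool a_is_f32, bool native_bf16_dot) {
  return WantDecompressA(a_is_f32, native_bf16_dot)
             ? 2                    // sizeof(BF16)
             : (a_is_f32 ? 4 : 2);  // sizeof(float) or sizeof(BF16)
}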
|
@ -909,16 +1172,57 @@ class MMPerPackage {
|
|||
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
|
||||
}
|
||||
|
||||
// Use instead of `MatPtr::IsPacked` because that returns true for single
|
||||
// rows, but we want to know whether there is padding.
|
||||
static bool HasPadding(const MatPtr& mat) {
|
||||
return mat.Stride() > mat.Cols();
|
||||
}
|
||||
|
||||
// Returns a 2D subrange whose top-left is `r, c` and width is `cols`. Both `A`
|
||||
// and `B` are const, but StridedView is also used for non-const `partial`.
|
||||
template <typename T>
|
||||
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
|
||||
size_t cols) {
|
||||
HWY_DASSERT(c < AB.Cols());
|
||||
HWY_DASSERT(cols <= AB.Cols() - c);
|
||||
HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(hn::ScalableTag<T>());
|
||||
(void)N;
|
||||
// If `AB` is padded, then `LoopKC` expects the view is either a vector
|
||||
// multiple, or all columns and thus also padded.
|
||||
HWY_DASSERT(!HasPadding(AB) || (cols % N == 0 || cols == AB.Cols()));
|
||||
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
|
||||
}
|
||||
|
||||
// `TA` is usually BF16, but can be F32 if `!HWY_NATIVE_DOT_BF16`.
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DispatchOrder(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B,
|
||||
RowPtrs<TC> C_rows) const {
|
||||
switch (order_) {
|
||||
case MMOrder::kNT:
|
||||
return DoNT(A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_K:
|
||||
return DoNT_K(A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_MT:
|
||||
return DoNT_MT(A, A_padded, B, C_rows);
|
||||
case MMOrder::kNT_MT_K:
|
||||
return DoNT_MT_K(A, A_padded, B, C_rows);
|
||||
default:
|
||||
HWY_UNREACHABLE;
|
||||
}
|
||||
}
|
||||
|
||||
// Single M and K ranges, parallel N. Fills all of C directly.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
|
||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||
const IndexRange& range_M = ranges_mc_.Range(0);
|
||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||
const size_t K = range_K.Num();
|
||||
const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K);
|
||||
const StridedView<TA> A_view = A.View(range_M.begin(), 0, K);
|
||||
const size_t B_stride =
|
||||
Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_);
|
||||
|
||||
|
|
@ -936,8 +1240,8 @@ class MMPerPackage {
|
|||
row_b += kNR) {
|
||||
StridedViewBF B_view =
|
||||
DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
|
||||
args_, C_rows);
|
||||
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_M, row_b, K,
|
||||
MMSetC(), args_, C_rows);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -945,8 +1249,9 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// Single M range, parallel N, sequential K. Fills all of partial.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_K(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
|
||||
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
|
||||
const IndexRange& range_mc = ranges_mc_.Range(0);
|
||||
|
|
@ -958,8 +1263,8 @@ class MMPerPackage {
|
|||
const IndexRange& range_nc,
|
||||
auto out_tag) HWY_ATTR {
|
||||
const size_t kc = range_kc.Num();
|
||||
const StridedViewBF& A_view =
|
||||
A_.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
const StridedView<TA> A_view =
|
||||
A.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
const StridedViewBF B_storage_view(
|
||||
B_storage, kc,
|
||||
Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_));
|
||||
|
|
@ -967,8 +1272,8 @@ class MMPerPackage {
|
|||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
||||
C_rows);
|
||||
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
|
||||
out_tag, args_, C_rows);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -1013,8 +1318,9 @@ class MMPerPackage {
|
|||
|
||||
// Parallel loops over mc/nc blocks of M/range_np, single K.
|
||||
// Fills `mc x nc` sections of C directly, in parallel.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
|
||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||
|
|
@ -1031,7 +1337,7 @@ class MMPerPackage {
|
|||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(worker, zone, args_);
|
||||
|
||||
const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K);
|
||||
const StridedView<TA> A_view = A.View(range_mc.begin(), 0, K);
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const StridedViewBF B_storage_view(B_storage, K, B_stride);
|
||||
|
||||
|
|
@ -1039,8 +1345,8 @@ class MMPerPackage {
|
|||
row_b += kNR) {
|
||||
StridedViewBF B_view =
|
||||
DecompressB(B, row_b, range_K, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
|
||||
args_, C_rows);
|
||||
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, K,
|
||||
MMSetC(), args_, C_rows);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -1049,8 +1355,9 @@ class MMPerPackage {
|
|||
|
||||
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
|
||||
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
template <typename TA, typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_MT_K(const StridedView<TA> A, const bool A_padded,
|
||||
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
|
||||
static const auto fill_zone =
|
||||
args_.env->ctx.profiler.AddZone("MM.NT_MT_K.FillC");
|
||||
|
|
@ -1067,14 +1374,14 @@ class MMPerPackage {
|
|||
const IndexRange& range_nc,
|
||||
auto out_tag) HWY_ATTR {
|
||||
const size_t kc = range_kc.Num();
|
||||
const StridedViewBF& A_view =
|
||||
A_.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
const StridedView<TA> A_view =
|
||||
A.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
|
||||
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
|
||||
C_rows);
|
||||
MMKernel::A2C0(A_view, A_padded, B_view, mr_, range_mc, row_b, kc,
|
||||
out_tag, args_, C_rows);
|
||||
}
|
||||
}; // loop_nc
|
||||
args_.env->parallel.ForRangesMC_NC(
|
||||
|
|
@ -1107,17 +1414,16 @@ class MMPerPackage {
|
|||
});
|
||||
}
|
||||
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_`. Assumes `TA` is a
|
||||
// seekable type (i.e., not NUQ) so we can use pointer arithmetic.
|
||||
template <typename TA>
|
||||
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
|
||||
// Decompresses all `M x K` from `A` into padded BF16 `A_view`.
|
||||
HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view,
|
||||
MMParA par_a) const {
|
||||
const IndexRange all_M(0, A.Rows());
|
||||
const IndexRange all_K(0, A.Cols());
|
||||
HWY_DASSERT(all_K.Num() == A_.Cols());
|
||||
HWY_DASSERT(all_K.Num() == A_view.Cols());
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const size_t NBF = hn::Lanes(dbf);
|
||||
static_assert(hwy::IsSameEither<TA, BF16, float>(), "Can seek");
|
||||
|
||||
static const auto zone = args_.env->ctx.profiler.AddZone("MM.DecompressA");
|
||||
|
||||
|
|
@ -1133,8 +1439,9 @@ class MMPerPackage {
|
|||
// otherwise `DecompressAndZeroPad` overwrites neighbors.
|
||||
HWY_DASSERT(cols % NBF == 0 || range_K.end() == A.Cols());
|
||||
for (size_t row_a : range_M) {
|
||||
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
|
||||
const PackedSpan<const float> from =
|
||||
MakeSpan(A.Row(row_a) + col0, cols);
|
||||
BF16* HWY_RESTRICT to = A_view.Row(row_a) + col0;
|
||||
DecompressAndZeroPad(dbf, from, 0, to, cols);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
|
|
@ -1175,23 +1482,12 @@ class MMPerPackage {
|
|||
}
|
||||
|
||||
// Autotuning wrapper for `DoDecompressA`.
|
||||
template <typename TA>
|
||||
HWY_INLINE StridedViewBF DecompressA(const MatPtrT<TA>& A) const {
|
||||
HWY_INLINE void DecompressA(const MatPtrT<float>& A,
|
||||
const StridedViewBF A_view) const {
|
||||
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
|
||||
// If already BF16, maybe return a view:
|
||||
if constexpr (hwy::IsSame<TA, BF16>()) {
|
||||
// Only if vector multiple and padded (see `DoDecompressA`).
|
||||
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
|
||||
if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) {
|
||||
// Const, but cast because StridedView is also used for `partial` which
|
||||
// is non-const.
|
||||
return StridedViewBF(const_cast<TA*>(A.Row(0)), A.Cols(), A.Stride());
|
||||
}
|
||||
}
|
||||
|
||||
if (HWY_LIKELY(autotune.Best())) {
|
||||
DoDecompressA(A, *autotune.Best());
|
||||
return A_;
|
||||
return DoDecompressA(A, A_view, *autotune.Best());
|
||||
}
|
||||
|
||||
// First call: generate candidates.
|
||||
|
|
@ -1204,7 +1500,7 @@ class MMPerPackage {
|
|||
|
||||
const MMParA& par_a = autotune.NextConfig();
|
||||
const uint64_t t0 = hwy::timer::Start();
|
||||
DoDecompressA(A, par_a);
|
||||
DoDecompressA(A, A_view, par_a);
|
||||
const uint64_t t1 =
|
||||
args_.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
|
||||
const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
|
||||
|
|
@ -1213,7 +1509,6 @@ class MMPerPackage {
|
|||
static_cast<double>(min_elapsed) /
|
||||
hwy::platform::InvariantTicksPerSecond() * 1E6);
|
||||
}
|
||||
return A_;
|
||||
}
|
||||
|
||||
// Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
|
||||
|
|
@ -1223,12 +1518,17 @@ class MMPerPackage {
|
|||
HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
|
||||
const IndexRange& range_kc,
|
||||
const StridedViewBF& B_view) const {
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
|
||||
|
||||
// View() is safe if vector multiple, or padded: for the latter, `ZeroInit`
|
||||
// and weights.cc zero-initialize the padding.
|
||||
if constexpr (hwy::IsSame<TB, BF16>()) {
|
||||
return StridedViewBF(const_cast<BF16*>(B.Row(row_b)) + range_kc.begin(),
|
||||
range_kc.Num(), B.Stride());
|
||||
if (B.Cols() % NBF == 0 || HasPadding(B)) {
|
||||
return View(B, row_b, range_kc.begin(), range_kc.Num());
|
||||
}
|
||||
}
|
||||
|
||||
const hn::ScalableTag<BF16> dbf;
|
||||
const PackedSpan<const TB> B_span = B.PaddedSpan();
|
||||
|
||||
const size_t kc = range_kc.Num();
|
||||
|
|
@ -1240,7 +1540,7 @@ class MMPerPackage {
|
|||
DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
|
||||
// Verify that we zero-padded.
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, hn::Lanes(dbf)); ++i) {
|
||||
for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
|
||||
HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
|
||||
}
|
||||
}
|
||||
|
|
@ -1250,7 +1550,6 @@ class MMPerPackage {
|
|||
|
||||
const MMArgs args_; // copy for locality
|
||||
const size_t pkg_idx_;
|
||||
StridedViewBF A_; // view into A or pkg_A_, both of which are padded.
|
||||
|
||||
const IndexRange range_np_;
|
||||
// From MMConfig:
|
||||
|
|
@ -1293,13 +1592,14 @@ struct MMImpl {
|
|||
MMZone mm_zone;
|
||||
mm_zone.MaybeEnter(pkg_idx, zone, args);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B,
|
||||
C_rows);
|
||||
});
|
||||
} else {
|
||||
const size_t pkg_idx = 0;
|
||||
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
|
||||
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
|
||||
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
|
||||
MMPerPackage(A.Extents(), args, config, pkg_idx, range_np)(A, B, C_rows);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -1310,7 +1610,7 @@ struct MMImpl {
|
|||
// `K = B.Cols()`, which must match `A.Cols()`, is the number
|
||||
// of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
|
||||
// are no other restrictions on shape, though performance is better when `M % 4
|
||||
// == 0` or `M <= 4`, and when A is padded (`!A.IsPacked()`).
|
||||
// == 0` or `M <= 4`, and when A is padded (Stride() > Cols()).
|
||||
//
|
||||
// NOTE: if A and/or B are BF16 and padded, the interval `[Cols(),
|
||||
// hwy::RoundUpTo(Cols(), hn::Lanes(dbf))` must be zero-initialized to match
|
||||
|
|
@ -1376,8 +1676,9 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
HWY_ASSERT(N <= MMStorage::kMaxN);
|
||||
HWY_ASSERT(N % kNR == 0);
|
||||
|
||||
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
|
||||
kNR, per_key.ranges_np, env.print_config));
|
||||
tuner.SetCandidates(
|
||||
MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
|
||||
kMaxMR, kNR, per_key.ranges_np, env.print_config));
|
||||
}
|
||||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
|
|
|
|||
|
|
@ -64,19 +64,21 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
|
|||
class GenerateCandidates {
|
||||
public:
|
||||
GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
|
||||
size_t sizeof_TC, size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np, bool print_config)
|
||||
size_t sizeof_TA, size_t sizeof_TC, size_t max_mr,
|
||||
size_t nr, const IndexRangePartition& ranges_np,
|
||||
bool print_config)
|
||||
: allocator_(allocator),
|
||||
M_(M),
|
||||
K_(K),
|
||||
N_(N),
|
||||
sizeof_TA_(sizeof_TA),
|
||||
sizeof_TC_(sizeof_TC),
|
||||
max_mr_(max_mr),
|
||||
nr_(nr),
|
||||
// These influence kc/nc, but are also stored in `MMConfig` for
|
||||
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
|
||||
// is likely still in L1, but we expect K > 1000 and might as well round
|
||||
// up to the line size.
|
||||
// up to the line size. Use BF16, not sizeof_TA, because B is BF16.
|
||||
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
|
||||
nc_multiple_(allocator.StepBytes() / sizeof_TC),
|
||||
ranges_np_(ranges_np),
|
||||
|
|
@ -176,8 +178,9 @@ class GenerateCandidates {
|
|||
// subtract the output and buf, and allow using more than the actual L1
|
||||
// size. This results in an overestimate, and the loop below will propose
|
||||
// the next few smaller values for the autotuner to evaluate.
|
||||
const size_t bytes_ab = allocator_.L1Bytes() * 3;
|
||||
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
|
||||
const size_t bytes_ab =
|
||||
allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream));
|
||||
const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16);
|
||||
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
|
||||
kc_max =
|
||||
RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
|
||||
|
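For a rough sense of the resulting magnitudes (illustrative numbers, not measured): with a 32 KiB L1, BF16 A (`sizeof_TA` = 2) and assuming `sizeof(SfpStream)` is one byte, `bytes_ab` = 32768 * 3 = 98304; with `rows_a` = 4 and `nr` = 4, `col_bytes` = 4 * 2 + 4 * 2 = 16, so `kc_max` = DivCeil(98304, 16) = 6144 before clamping to `MMStorage::kMaxKC` and rounding down to `kc_multiple_`.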
|
@ -224,7 +227,7 @@ class GenerateCandidates {
|
|||
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
|
||||
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
|
||||
// partial.
|
||||
const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
|
||||
const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes();
|
||||
size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
|
||||
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
|
||||
HWY_DASSERT(mc_max != 0);
|
||||
|
|
@ -359,6 +362,7 @@ class GenerateCandidates {
|
|||
const size_t M_;
|
||||
const size_t K_;
|
||||
const size_t N_;
|
||||
const size_t sizeof_TA_;
|
||||
const size_t sizeof_TC_;
|
||||
|
||||
const size_t max_mr_;
|
||||
|
|
@ -376,12 +380,12 @@ class GenerateCandidates {
|
|||
|
||||
// Facade to avoid exposing `GenerateCandidates` in the header.
|
||||
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
|
||||
size_t K, size_t N, size_t sizeof_TC,
|
||||
size_t max_mr, size_t nr,
|
||||
size_t K, size_t N, size_t sizeof_TA,
|
||||
size_t sizeof_TC, size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np,
|
||||
bool print_config) {
|
||||
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
|
||||
ranges_np, print_config)();
|
||||
return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr,
|
||||
nr, ranges_np, print_config)();
|
||||
}
|
||||
|
||||
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
|
||||
|
|
|
|||
|
|
@ -281,7 +281,9 @@ class MMStorage {
|
|||
BindC(partial_storage_, parallel);
|
||||
}
|
||||
|
||||
// Returns per-package matrix view.
|
||||
// Returns per-package matrix view. Converting A=F32 to BF16 up-front is
|
||||
// faster than on-the-fly when native BF16 is available: it only happens once,
|
||||
// not per B tile row, and the cache footprint is smaller.
|
||||
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
|
||||
HWY_DASSERT(extents.rows <= kMaxM);
|
||||
HWY_DASSERT(extents.cols <= kMaxK);
|
||||
|
|
@ -475,8 +477,8 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
|||
#pragma pack(pop)
|
||||
|
||||
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
|
||||
size_t K, size_t N, size_t sizeof_TC,
|
||||
size_t max_mr, size_t nr,
|
||||
size_t K, size_t N, size_t sizeof_TA,
|
||||
size_t sizeof_TC, size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np,
|
||||
bool print_config);
|
||||
|
||||
|
|
|
|||
|
|
@ -118,17 +118,26 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
|
||||
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
|
||||
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
|
||||
// Dot() uses double-precision summation.
|
||||
double tolerance = 12 * norm * eps_f32;
|
||||
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
|
||||
// tolerance there.
|
||||
if (IsF32<TA>() && IsF32<TB>()) {
|
||||
tolerance += 4 * max_abs * eps_bf16;
|
||||
// If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to
|
||||
// BF16, so add extra tolerance.
|
||||
if (IsF32<TB>()) {
|
||||
tolerance += 2 * max_abs * eps_bf16;
|
||||
}
|
||||
|
||||
if (tolerance > 500.0) {
|
||||
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
|
||||
}
|
||||
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
|
||||
const double rel_tolerance =
|
||||
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
|
||||
|
||||
double max_rel = 0.0;
|
||||
size_t worst_r = 0;
|
||||
size_t worst_c = 0;
|
||||
double worst_actual = 0.0;
|
||||
double worst_expected = 0.0;
|
||||
size_t num_outside = 0;
|
||||
for (size_t r = 0; r < A.Rows(); r++) {
|
||||
const float* expected_row = c_slow_batch.Row(r);
|
||||
const float* actual_row = c_batch.Row(r);
|
||||
|
|
@ -143,15 +152,24 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
const double min = HWY_MIN(expected_value, actual_value);
|
||||
const double rel = max / HWY_MAX(min, 1E-6);
|
||||
if (rel > max_rel) {
|
||||
hwy::Abort(__FILE__, line,
|
||||
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
|
||||
"tolerance %f rel %E max_rel %E\n",
|
||||
r, c, expected_value, actual_value, norm, max_abs,
|
||||
tolerance, rel, max_rel);
|
||||
worst_expected = expected_value;
|
||||
worst_actual = actual_value;
|
||||
worst_r = r;
|
||||
worst_c = c;
|
||||
max_rel = rel;
|
||||
++num_outside;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (max_rel > rel_tolerance) {
|
||||
hwy::Abort(__FILE__, line,
|
||||
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
|
||||
"tolerance %f rel %E max_rel %E num_outside %zu\n",
|
||||
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
|
||||
tolerance, max_rel, rel_tolerance, num_outside);
|
||||
}
|
||||
}
|
||||
|
||||
// B is already transposed.
|
||||
|
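As a concrete instance of the updated tolerance (illustrative values): `eps_f32` is about 1.19e-7 and `eps_bf16` is 2^-7, about 7.8e-3. With `norm` = 100, the base tolerance is 12 * 100 * 1.19e-7, roughly 1.4e-4; if B is F32 and `max_abs` = 4, the extra BF16-demotion term adds 2 * 4 * 7.8e-3, roughly 0.0625, which dominates.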
|
@ -188,9 +206,9 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
|
|||
TC* HWY_RESTRICT C_row = C.Row(r);
|
||||
for (size_t c : cols_c) {
|
||||
const float add = add_row ? add_row[c] : 0.0f;
|
||||
C_row[c] = hwy::ConvertScalarTo<TC>(
|
||||
add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
|
||||
A.Cols()));
|
||||
const float dot =
|
||||
Dot(df, b_span, c * B.Stride(), A.Row(r), A.Cols());
|
||||
C_row[c] = hwy::ConvertScalarTo<TC>(add + scale * dot);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -279,6 +297,9 @@ void TestTiny() {
|
|||
for (size_t K = 1; K <= 64; K *= 2) {
|
||||
for (size_t N = 4; N <= 64; N += max_packages * 4) {
|
||||
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
|
||||
TestMatMul<BF16, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
|
||||
TestMatMul<F32, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
|
||||
TestMatMul<BF16, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -334,6 +355,10 @@ void TestAllMatMul() {
|
|||
TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
|
||||
TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
|
||||
|
||||
// Non-vector-multiple K.
|
||||
TestMatMul<F32, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
|
||||
TestMatMul<BF16, BF16>(128, 258, 128, /*add=*/true, env, __LINE__);
|
||||
|
||||
// minimal non-square test. kColsARowsB must be at least 2 vectors.
|
||||
TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
|
||||
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);
|
||||
|
|
|
|||