Support bf16 output of Matmul

Adds Stride to ConstMat, to support decompression of C output for test
matmul_test: add line numbers to output
Also ignore "N is not a multiple of nc" when N==nc
PiperOrigin-RevId: 731096662
This commit is contained in:
Jan Wassenberg 2025-02-25 17:52:50 -08:00 committed by Copybara-Service
parent b3b4b9f92f
commit 2bdf26d81d
6 changed files with 353 additions and 262 deletions

View File

@ -734,7 +734,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
// Hidden layer -> output layer. // Hidden layer -> output layer.
auto activations_mat = MakeConstMat( auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim)); hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
} }
@ -773,8 +774,9 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
multiplier.Row(0), ff_hidden_dim * num_interleaved); multiplier.Row(0), ff_hidden_dim * num_interleaved);
// Hidden layer -> output layer. // Hidden layer -> output layer.
auto activations_mat = MakeConstMat( auto activations_mat = MakeConstMat(hidden_activations.Row(0),
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim)); Extents2D(num_interleaved, ff_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
} }

View File

@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// Generates inputs and prints observed throughput of MatMul. // Generates inputs and prints observed throughput of MatMul.
// M = A rows, K = A cols, N = C cols. // M = A rows, K = A cols, N = C cols.
template <typename MatTA, typename MatTB = MatTA> template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
if (env.print_config || env.print_measurement) { if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N, fprintf(stderr,
add, TypeName<MatTA>(), TypeName<MatTB>()); "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", //
M, K, N, add, TypeName<TA>(), TypeName<TB>(), TypeName<TC>());
const Extents2D A_extents(M, K); const Extents2D A_extents(M, K);
const Extents2D B_extents(N, K); // already transposed const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N); const Extents2D C_extents(M, N);
RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents); RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents); RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
std::unique_ptr<MatStorageT<float>> add_storage; std::unique_ptr<MatStorageT<float>> add_storage;
if (add) { if (add) {
@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
add_storage->set_scale(1.0f); add_storage->set_scale(1.0f);
} }
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool); MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool); MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
HWY_ASSERT(a && b_trans); HWY_ASSERT(a && b_trans);
const auto A = ConstMatFromWeights(*a); const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans); const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr; const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C = RowPtrFromBatch(c_batch); const RowPtr<TC> C = RowPtrFromBatch(c_batch);
// Fewer reps for large batch sizes, which take longer. // Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12; const size_t num_samples = M < 32 ? 20 : 12;
@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and // Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling // spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op. // BindB/C if there is a single package: they will be a no-op.
BindB(B_extents.rows, B, env.parallel); BindB(B_extents.rows, sizeof(TC), B, env.parallel);
BindC(A_extents.rows, C, env.parallel); BindC(A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault; Tristate use_spinning = Tristate::kDefault;
@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
per_key = MatMul(A, B, add_row, env, C); per_key = MatMul(A, B, add_row, env, C);
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
double elapsed = t1 - t0; double elapsed = t1 - t0;
keep += C.Row(0)[hwy::Unpredictable1()]; keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
// Only record times after autotuning finished. // Only record times after autotuning finished.
if (per_key->autotune.Best()) times.push_back(elapsed); if (per_key->autotune.Best()) times.push_back(elapsed);
@ -229,8 +230,8 @@ void BenchAllMatMul() {
for (size_t batch_size : {1, 4, 128, 512}) { for (size_t batch_size : {1, 4, 128, 512}) {
constexpr bool kAdd = false; constexpr bool kAdd = false;
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env); BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env); BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
} }
PROFILER_PRINT_RESULTS(); PROFILER_PRINT_RESULTS();

View File

@ -46,7 +46,7 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
// Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register.
template <class DF, class DBF = hn::Repartition<hwy::bfloat16_t, DF>> template <class DF, class DBF = hn::Repartition<BF16, DF>>
static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) { static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
// Promoting odd means clearing the lower 16 bits. Doing this via AND // Promoting odd means clearing the lower 16 bits. Doing this via AND
// requires a second input vector, which we prefer to avoid due to high // requires a second input vector, which we prefer to avoid due to high
@ -70,6 +70,16 @@ static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
#endif #endif
} }
// Converts from float intermediate to MatMul output type `TC`.
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_F32_D(DC)>
hn::Vec<DC> TCFromF32(DC /*dc*/, hn::Vec<DF> vf) {
return vf;
}
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_BF16_D(DC)>
hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
return hn::DemoteTo(dc, vf);
}
// Tag classes, passed to `MMKernel::A2C0` to choose between writing one // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
// first kc result to partial, or accumulating the next kc result into partial // first kc result to partial, or accumulating the next kc result into partial
@ -93,14 +103,14 @@ class MMStoreHorizontalSumsIntoC {
// Thus we compute the horizontal sums of each `Crc`. The elements may be // Thus we compute the horizontal sums of each `Crc`. The elements may be
// permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
// this does not change their horizontal sum. // this does not change their horizontal sum.
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>, typename TC>
HWY_INLINE void operator()(DF df, // HWY_INLINE void operator()(DF df, //
VF C00, VF C01, VF C02, VF C03, // VF C00, VF C01, VF C02, VF C03, //
VF C10, VF C11, VF C12, VF C13, // VF C10, VF C11, VF C12, VF C13, //
VF C20, VF C21, VF C22, VF C23, // VF C20, VF C21, VF C22, VF C23, //
VF C30, VF C31, VF C32, VF C33, // VF C30, VF C31, VF C32, VF C33, //
const size_t row_c, const size_t col_c, const size_t row_c, const size_t col_c,
const MMArgs& args) const { const MMArgs& args, const RowPtr<TC>& C) const {
float buf[16 * hn::MaxLanes(df)]; float buf[16 * hn::MaxLanes(df)];
const size_t N = hn::Lanes(df); const size_t N = hn::Lanes(df);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@ -136,10 +146,10 @@ class MMStoreHorizontalSumsIntoC {
if constexpr (kAdd) { if constexpr (kAdd) {
vadd = hn::Load(d4, args.add + col_c); vadd = hn::Load(d4, args.add + col_c);
} }
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, args.C, row_c, col_c); MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, args.C, row_c, col_c); MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, args.C, row_c, col_c); MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, args.C, row_c, col_c); MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
} }
private: private:
@ -155,34 +165,36 @@ class MMStoreHorizontalSumsIntoC {
} }
// Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4. // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
template <size_t kRow, class D4, class V4 = hn::Vec<D4>> template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N, static HWY_INLINE VF4 MaybeLoad(DF4 df4, size_t N,
const float* HWY_RESTRICT buf) { const float* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) { if constexpr (kRow < kRowsAC) {
return hn::Load(d4, buf + 4 * kRow * N); return hn::Load(df4, buf + 4 * kRow * N);
} else { } else {
return hn::Zero(d4); return hn::Zero(df4);
} }
} }
template <size_t kRow, class D4, class V4 = hn::Vec<D4>> template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum, static HWY_INLINE VF4 MaybeAdd(DF4 df4, size_t N, VF4 sum,
const float* HWY_RESTRICT buf) { const float* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) { if constexpr (kRow < kRowsAC) {
return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N)); return hn::Add(sum, hn::Load(df4, buf + 4 * kRow * N));
} else { } else {
return sum; return sum;
} }
} }
template <size_t kRow, class D4, class V4 = hn::Vec<D4>> template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE void MaybeScaleAndStore(D4 d4, V4 sum, V4 vscale, V4 vadd, static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
const RowPtrF& C, VF4 vadd, const RowPtr<TC>& C,
const size_t row_c, const size_t row_c,
const size_t col_c) { const size_t col_c) {
if constexpr (kRow < kRowsAC) { if constexpr (kRow < kRowsAC) {
float* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c; TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
hn::Store(hn::MulAdd(sum, vscale, vadd), d4, pos); const hn::Rebind<TC, DF4> dc4;
const VF4 out = hn::MulAdd(sum, vscale, vadd);
hn::Store(TCFromF32(dc4, out), dc4, pos);
} }
} }
}; // MMStoreHorizontalSumsIntoC }; // MMStoreHorizontalSumsIntoC
@ -340,14 +352,14 @@ class MMKernel {
// or less on ISAs with fewer registers, or for the last few rows of A. // or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4; static constexpr size_t kMaxMR = 4;
// Calls `LoopOverKC` for each of `mc` rows of A in steps of `mr`. `A_view` // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
template <class Tag> template <class Tag, typename TC>
static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view, static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
size_t mr, const IndexRange& range_mc, size_t mr, const IndexRange& range_mc,
const size_t row_b, size_t kc, Tag tag, const size_t row_b, size_t kc, Tag tag,
const MMArgs& args) { const MMArgs& args, const RowPtr<TC>& C) {
HWY_DASSERT(1 <= mr && mr <= kMaxMR); HWY_DASSERT(1 <= mr && mr <= kMaxMR);
const size_t row0 = range_mc.begin(); const size_t row0 = range_mc.begin();
const size_t mc = range_mc.Num(); const size_t mc = range_mc.Num();
@ -356,7 +368,7 @@ class MMKernel {
// M == 1, or x86 with 8 SIMD registers: // M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) { if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) { for (; imc < mc; ++imc) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
} }
return; return;
} }
@ -365,11 +377,11 @@ class MMKernel {
if (HWY_UNLIKELY(mr == 2)) { if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) { if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) { for (; imc <= mc - 2; imc += 2) {
LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
} }
} }
if (HWY_UNLIKELY(imc != mc)) { if (HWY_UNLIKELY(imc != mc)) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
} }
return; return;
} }
@ -377,17 +389,17 @@ class MMKernel {
HWY_DASSERT(mr == 4); HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) { if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) { for (; imc <= mc - 4; imc += 4) {
LoopOverKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
} }
} }
const size_t remainder_mc = mc - imc; const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4); HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) { if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
imc += 2; imc += 2;
} }
if (HWY_UNLIKELY(remainder_mc & 1)) { if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args); LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
imc += 1; imc += 1;
} }
HWY_DASSERT(imc == mc); HWY_DASSERT(imc == mc);
@ -484,11 +496,11 @@ class MMKernel {
// with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
// BF16 so we can load directly without `Decompress2`, which is expensive for // BF16 so we can load directly without `Decompress2`, which is expensive for
// NUQ and requires 2x unrolling, which requires more loads. // NUQ and requires 2x unrolling, which requires more loads.
template <size_t kRowsAC, class Tag> template <size_t kRowsAC, class Tag, typename TC>
static HWY_INLINE void LoopOverKC(const RowPtrBF& A_view, static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
const RowPtrBF& B_view, size_t row_ac, size_t row_ac, size_t imc, size_t col_c,
size_t imc, size_t col_c, size_t kc, size_t kc, Tag tag, const MMArgs& args,
Tag tag, const MMArgs& args) { const RowPtr<TC>& C) {
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
using VBF = hn::Vec<decltype(dbf)>; using VBF = hn::Vec<decltype(dbf)>;
const size_t NBF = hn::Lanes(dbf); const size_t NBF = hn::Lanes(dbf);
@ -602,11 +614,11 @@ class MMKernel {
if (args.add) { if (args.add) {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()( MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args); C31, C32, C33, row_ac, col_c, args, C);
} else { } else {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()( MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args); C31, C32, C33, row_ac, col_c, args, C);
} }
} else { } else {
MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()( MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@ -627,40 +639,44 @@ class MMScaleDemoteAdd {
// TODO: fuse with subsequent operations - function pointer? // TODO: fuse with subsequent operations - function pointer?
// Although this region in `outputs.C` is not touched again, streaming stores // Although this region in `outputs.C` is not touched again, streaming stores
// do not help on SKX and Zen4. TODO: re-check this. // do not help on SKX and Zen4. TODO: re-check this.
template <typename TC>
static HWY_INLINE void FillC(const IndexRange& range_mc, static HWY_INLINE void FillC(const IndexRange& range_mc,
const IndexRange& range_nc, const MMArgs& args) { const IndexRange& range_nc, const MMArgs& args,
const RowPtr<TC>& C) {
size_t row_c = range_mc.begin(); size_t row_c = range_mc.begin();
if (args.add) { if (args.add) {
constexpr bool kAdd = true; constexpr bool kAdd = true;
if (range_mc.Num() >= 4) { if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) { for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args); Do4Rows<kAdd>(row_c, range_nc, args, C);
} }
} }
for (; row_c < range_mc.end(); ++row_c) { for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args); Do1Row<kAdd>(row_c, range_nc, args, C);
} }
} else { } else {
constexpr bool kAdd = false; constexpr bool kAdd = false;
if (range_mc.Num() >= 4) { if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) { for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args); Do4Rows<kAdd>(row_c, range_nc, args, C);
} }
} }
for (; row_c < range_mc.end(); ++row_c) { for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args); Do1Row<kAdd>(row_c, range_nc, args, C);
} }
} }
} }
private: private:
// Unrolled for 4 rows to reduce the number of loads from `add`. // Unrolled for 4 rows to reduce the number of loads from `add`.
template <bool kAdd> template <bool kAdd, typename TC>
static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
const MMArgs& args) { const MMArgs& args, const RowPtr<TC>& C) {
const hn::ScalableTag<double> dd; const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>; using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd); const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale); const VD vscale = hn::Set(dd, args.scale);
@ -669,10 +685,10 @@ class MMScaleDemoteAdd {
const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2); const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3); const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
float* HWY_RESTRICT cr1 = args.C.Row(row_c + 1); TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
float* HWY_RESTRICT cr2 = args.C.Row(row_c + 2); TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
float* HWY_RESTRICT cr3 = args.C.Row(row_c + 3); TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
// We manually unroll 2x for higher IPC in batch=1. // We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin(); size_t col_c = range_nc.begin();
@ -713,15 +729,24 @@ class MMScaleDemoteAdd {
m30 = hn::Mul(d30, vscale); m30 = hn::Mul(d30, vscale);
m31 = hn::Mul(d31, vscale); m31 = hn::Mul(d31, vscale);
} }
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
const VF f10 = hn::DemoteTo(df, m10);
const VF f11 = hn::DemoteTo(df, m11);
const VF f20 = hn::DemoteTo(df, m20);
const VF f21 = hn::DemoteTo(df, m21);
const VF f30 = hn::DemoteTo(df, m30);
const VF f31 = hn::DemoteTo(df, m31);
// Note that Stream is neutral on SKX and harmful on Zen4. // Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
hn::Store(hn::DemoteTo(df, m10), df, cr1 + col_c); hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c);
hn::Store(hn::DemoteTo(df, m11), df, cr1 + col_c + ND); hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND);
hn::Store(hn::DemoteTo(df, m20), df, cr2 + col_c); hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c);
hn::Store(hn::DemoteTo(df, m21), df, cr2 + col_c + ND); hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND);
hn::Store(hn::DemoteTo(df, m30), df, cr3 + col_c); hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c);
hn::Store(hn::DemoteTo(df, m31), df, cr3 + col_c + ND); hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND);
} }
} }
@ -750,25 +775,31 @@ class MMScaleDemoteAdd {
m20 = hn::Mul(d20, vscale); m20 = hn::Mul(d20, vscale);
m30 = hn::Mul(d30, vscale); m30 = hn::Mul(d30, vscale);
} }
hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); // First convert f64 to f32.
hn::StoreN(hn::DemoteTo(df, m10), df, cr1 + col_c, remaining); const VF f00 = hn::DemoteTo(df, m00);
hn::StoreN(hn::DemoteTo(df, m20), df, cr2 + col_c, remaining); const VF f10 = hn::DemoteTo(df, m10);
hn::StoreN(hn::DemoteTo(df, m30), df, cr3 + col_c, remaining); const VF f20 = hn::DemoteTo(df, m20);
const VF f30 = hn::DemoteTo(df, m30);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining);
} }
} }
// Same as above but handles a single row (for remainder rows). // Same as above but handles a single row (for remainder rows).
template <bool kAdd> template <bool kAdd, typename TC>
static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
const MMArgs& args) { const MMArgs& args, const RowPtr<TC>& C) {
const hn::ScalableTag<double> dd; const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>; using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd); const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale); const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0); const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0); TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
// We manually unroll 2x for higher IPC in batch=1. // We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin(); size_t col_c = range_nc.begin();
@ -791,9 +822,12 @@ class MMScaleDemoteAdd {
m00 = hn::Mul(d00, vscale); m00 = hn::Mul(d00, vscale);
m01 = hn::Mul(d01, vscale); m01 = hn::Mul(d01, vscale);
} }
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
// Note that Stream is neutral on SKX and harmful on Zen4. // Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c); hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND); hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
} }
} }
@ -813,7 +847,9 @@ class MMScaleDemoteAdd {
} else { } else {
m00 = hn::Mul(d00, vscale); m00 = hn::Mul(d00, vscale);
} }
hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining); // First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
} }
} }
}; // MMScaleDemoteAdd }; // MMScaleDemoteAdd
@ -849,20 +885,21 @@ class MMPerPackage {
// B is decompressed several call layers lower, but not all member functions // B is decompressed several call layers lower, but not all member functions
// depend on TB, so pass it as an argument instead of templating the class. // depend on TB, so pass it as an argument instead of templating the class.
template <typename TB> template <typename TB, typename TC>
HWY_NOINLINE void operator()(const ConstMat<TB>& B) const { HWY_NOINLINE void operator()(const ConstMat<TB>& B,
const RowPtr<TC>& C) const {
// TODO: include NUQ tables? NumPacked in ConstMat? // TODO: include NUQ tables? NumPacked in ConstMat?
const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows; const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
switch (order_) { switch (order_) {
case MMOrder::kNT: case MMOrder::kNT:
return DoNT(B, num_packed_B); return DoNT(B, num_packed_B, C);
case MMOrder::kNT_K: case MMOrder::kNT_K:
return DoNT_K(B, num_packed_B); return DoNT_K(B, num_packed_B, C);
case MMOrder::kNT_MT: case MMOrder::kNT_MT:
return DoNT_MT(B, num_packed_B); return DoNT_MT(B, num_packed_B, C);
case MMOrder::kNT_MT_K: case MMOrder::kNT_MT_K:
return DoNT_MT_K(B, num_packed_B); return DoNT_MT_K(B, num_packed_B, C);
default: default:
HWY_UNREACHABLE; HWY_UNREACHABLE;
} }
@ -878,13 +915,14 @@ class MMPerPackage {
// Granularity of `ForNP`. B rows produce C columns, so we // Granularity of `ForNP`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing. // want a multiple of the line size to prevent false sharing.
static size_t MultipleNP() { static size_t MultipleNP(size_t sizeof_TC) {
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof(float)); return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
} }
// Single M and K, parallel N. Fills all of C directly. // Single M and K, parallel N. Fills all of C directly.
template <typename TB> template <typename TB, typename TC>
HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B) const { HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone; MMZone zone;
zone.MaybeEnter("MM.NT", args_); zone.MaybeEnter("MM.NT", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -897,7 +935,7 @@ class MMPerPackage {
// Similar to `loop_nc` below, but here we hoisted `A_view`. // Similar to `loop_nc` below, but here we hoisted `A_view`.
args_.env->parallel.ForNP( args_.env->parallel.ForNP(
range_np_, MultipleNP(), inner_tasks_, pkg_idx_, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR { [&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride); const RowPtrBF B_view(B_storage, K, B_stride);
@ -910,7 +948,7 @@ class MMPerPackage {
DecompressB(B, num_packed_B, row_b, range_K, B_view); DecompressB(B, num_packed_B, row_b, range_K, B_view);
} }
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_); args_, C);
} }
}); });
@ -918,8 +956,9 @@ class MMPerPackage {
} }
// Single M, parallel N, sequential K. Fills all of partial. // Single M, parallel N, sequential K. Fills all of partial.
template <typename TB> template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B) const { HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone; MMZone zone;
zone.MaybeEnter("MM.NT_K", args_); zone.MaybeEnter("MM.NT_K", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1); HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -942,13 +981,13 @@ class MMPerPackage {
zone.MaybeEnter("MM.NT_K.DecB", args_); zone.MaybeEnter("MM.NT_K.DecB", args_);
DecompressB(B, num_packed_B, row_b, range_kc, B_view); DecompressB(B, num_packed_B, row_b, range_kc, B_view);
} }
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
args_); C);
} }
}; };
args_.env->parallel.ForNP( args_.env->parallel.ForNP(
range_np_, MultipleNP(), inner_tasks_, pkg_idx_, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR { [&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
@ -965,13 +1004,13 @@ class MMPerPackage {
MMZone fill_zone; MMZone fill_zone;
if (out_ == MMOut::kCopy) { if (out_ == MMOut::kCopy) {
fill_zone.MaybeEnter("MM.NT_K.FillC", args_); fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_); MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
} else if (out_ == MMOut::kParM) { } else if (out_ == MMOut::kParM) {
fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_); fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
args_.env->parallel.ForRangeMC( args_.env->parallel.ForRangeMC(
range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR { range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_, MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
args_); args_, C);
}); });
} else { } else {
HWY_UNREACHABLE; // kDirect is only used with kNT. HWY_UNREACHABLE; // kDirect is only used with kNT.
@ -980,8 +1019,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K. // Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel. // Fills `mc x nc` sections of C directly, in parallel.
template <typename TB> template <typename TB, typename TC>
HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B) const { HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone; MMZone zone;
zone.MaybeEnter("MM.NT_MT", args_); zone.MaybeEnter("MM.NT_MT", args_);
HWY_DASSERT(ranges_kc_.NumTasks() == 1); HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@ -1006,7 +1046,7 @@ class MMPerPackage {
DecompressB(B, num_packed_B, row_b, range_K, B_view); DecompressB(B, num_packed_B, row_b, range_K, B_view);
} }
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_); args_, C);
} }
}); });
@ -1015,8 +1055,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K. // Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel. // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TB> template <typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B) const { HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone; MMZone zone;
zone.MaybeEnter("MM.NT_MT_K", args_); zone.MaybeEnter("MM.NT_MT_K", args_);
const size_t kc_max = ranges_kc_.TaskSize(); const size_t kc_max = ranges_kc_.TaskSize();
@ -1039,8 +1080,8 @@ class MMPerPackage {
zone.MaybeEnter("MM.NT_MT_K.DecB", args_); zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
DecompressB(B, num_packed_B, row_b, range_kc, B_view); DecompressB(B, num_packed_B, row_b, range_kc, B_view);
} }
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
args_); C);
} }
}; // loop_nc }; // loop_nc
args_.env->parallel.ForRangesMC_NC( args_.env->parallel.ForRangesMC_NC(
@ -1063,7 +1104,7 @@ class MMPerPackage {
HWY_DASSERT(out_ == MMOut::kCopy); HWY_DASSERT(out_ == MMOut::kCopy);
MMZone fill_zone; MMZone fill_zone;
fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_); fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_); MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
}); });
} }
@ -1083,6 +1124,8 @@ class MMPerPackage {
const IndexRange& range_K) HWY_ATTR { const IndexRange& range_K) HWY_ATTR {
const size_t col0 = range_K.begin(); const size_t col0 = range_K.begin();
const size_t cols = range_K.Num(); const size_t cols = range_K.Num();
// otherwise, padding overwrites neighbors
HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols);
for (size_t row_a : range_M) { for (size_t row_a : range_M) {
const PackedSpan<const TA> from = const PackedSpan<const TA> from =
MakeSpan(A.ptr + A.Row(row_a) + col0, cols); MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
@ -1224,9 +1267,10 @@ struct MMImpl {
// Called from `MatMul` from two places: either with the next autotune config, // Called from `MatMul` from two places: either with the next autotune config,
// or with the best config. // or with the best config.
template <typename TA, typename TB> template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A, static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A,
const ConstMat<TB>& B, const MMArgs& args, const ConstMat<TB>& B, const RowPtr<TC>& C,
const MMArgs& args,
const MMConfig& config) { const MMConfig& config) {
MMZone matmul_zone; MMZone matmul_zone;
matmul_zone.MaybeEnter("MM.DoMatMul", args); matmul_zone.MaybeEnter("MM.DoMatMul", args);
@ -1235,7 +1279,7 @@ struct MMImpl {
args.env->parallel.ForPkg( args.env->parallel.ForPkg(
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) { args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx); const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A, args, config, pkg_idx, range_np)(B); MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
}); });
} }
}; };
@ -1260,10 +1304,10 @@ struct MMImpl {
// invalidated by the next call to `MatMul`. // invalidated by the next call to `MatMul`.
// //
// Uses considerable stack space: at least 40 KiB per thread. // Uses considerable stack space: at least 40 KiB per thread.
template <typename TA, typename TB> template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B, HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtrF& C) { const RowPtr<TC>& C) {
const size_t M = A.Extents().rows; const size_t M = A.Extents().rows;
const size_t K = A.Extents().cols; const size_t K = A.Extents().cols;
const size_t N = B.Extents().rows; const size_t N = B.Extents().rows;
@ -1281,15 +1325,16 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
// invalidates `MMAutoTune::Best()` // invalidates `MMAutoTune::Best()`
index = env.per_key.size(); index = env.per_key.size();
env.per_key.push_back(MMPerKey(max_packages, N, kNR, env.parallel)); env.per_key.push_back(
MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
} }
MMPerKey& per_key = env.per_key[index]; MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune; MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add, const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add,
env.storage.Partial(), C); env.storage.Partial());
if (HWY_LIKELY(tuner.Best())) { if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, args, *tuner.Best()); MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
return &per_key; return &per_key;
} }
@ -1306,13 +1351,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
HWY_ASSERT(N % kNR == 0); HWY_ASSERT(N % kNR == 0);
// Negligible CPU time. // Negligible CPU time.
tuner.SetCandidates(MMCandidates(M, K, N, MMKernel::kMaxMR, kNR, tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
per_key.ranges_np, env.print_config)); per_key.ranges_np, env.print_config));
} }
const MMConfig& cfg = tuner.NextConfig(); const MMConfig& cfg = tuner.NextConfig();
const uint64_t t0 = hwy::timer::Start(); const uint64_t t0 = hwy::timer::Start();
MMImpl::DoMatMul(A, B, args, cfg); MMImpl::DoMatMul(A, B, C, args, cfg);
const uint64_t t1 = const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start(); env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) / const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /

View File

@ -60,10 +60,13 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables. // and holds most of their arguments in member variables.
class GenerateCandidates { class GenerateCandidates {
public: public:
GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr, GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config) const IndexRangePartition& ranges_np, bool print_config)
: M_(M), : M_(M),
K_(K), K_(K),
N_(N),
sizeof_TC_(sizeof_TC),
max_mr_(max_mr), max_mr_(max_mr),
nr_(nr), nr_(nr),
// These influence kc/nc, but are also stored in `MMConfig` for // These influence kc/nc, but are also stored in `MMConfig` for
@ -71,7 +74,7 @@ class GenerateCandidates {
// is likely still in L1, but we expect K > 1000 and might as well round // is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size. // up to the line size.
kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))), kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
nc_multiple_(Allocator::StepBytes() / sizeof(float)), nc_multiple_(Allocator::StepBytes() / sizeof_TC),
ranges_np_(ranges_np), ranges_np_(ranges_np),
print_config_(print_config) {} print_config_(print_config) {}
@ -88,7 +91,7 @@ class GenerateCandidates {
for (size_t nc : NC(mr, mc, kc, order)) { for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) { for (int inner_tasks : all_inner_tasks) {
for (MMOut out : all_outs) { for (MMOut out : all_outs) {
const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_, const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, out, inner_tasks); nc_multiple_, order, out, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks(); const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks(); const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
@ -114,7 +117,7 @@ class GenerateCandidates {
private: private:
using SizeVec = std::vector<size_t>; using SizeVec = std::vector<size_t>;
// How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may // How many rows of A per call to `MMKernel::LoopKC`. Lower values may
// be better for SIMD targets with fewer registers. // be better for SIMD targets with fewer registers.
SizeVec MR() const { SizeVec MR() const {
const int64_t target = hwy::DispatchedTarget(); const int64_t target = hwy::DispatchedTarget();
@ -153,7 +156,7 @@ class GenerateCandidates {
// The number of A and B columns to read between updating `partial`. // The number of A and B columns to read between updating `partial`.
SizeVec KC(size_t mr, MMOrder order) const { SizeVec KC(size_t mr, MMOrder order) const {
// `LoopOverKC` handles up to `mr` rows of A. // `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr); const size_t rows_a = HWY_MIN(M_, mr);
// After looping over `kc` columns, we write `mr x 4` outputs and 16 vector // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
@ -161,9 +164,9 @@ class GenerateCandidates {
// is important that B fits in L1, because batch=1 only has a single row of // is important that B fits in L1, because batch=1 only has a single row of
// A and thus no reuse of the packed B. When L1-resident, we can use the // A and thus no reuse of the packed B. When L1-resident, we can use the
// separate `DecompressAndZeroPad` to write `kc` columns, rather than having // separate `DecompressAndZeroPad` to write `kc` columns, rather than having
// to integrate `Decompress2` into `LoopOverKC`, which is less efficient for // to integrate `Decompress2` into `LoopKC`, which is less efficient for
// TB=NUQ due to less amortization of the table loads. Due to the low L1 // TB=NUQ due to less amortization of the table loads. Due to the low L1
// latency, the packing is still effectively fused into `LoopOverKC`. It may // latency, the packing is still effectively fused into `LoopKC`. It may
// be better to round up and accept a few L2 accesses in exchange for // be better to round up and accept a few L2 accesses in exchange for
// fewer loops over K, and thus fewer writes to `partial`. Hence we do not // fewer loops over K, and thus fewer writes to `partial`. Hence we do not
// subtract the output and buf, and allow using more than the actual L1 // subtract the output and buf, and allow using more than the actual L1
@ -255,7 +258,7 @@ class GenerateCandidates {
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize(); const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max; size_t nc_max = np_max;
const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double); const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double);
// Only if there will be reuse of B: choose the largest `nc_max` (C cols) // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded. // Otherwise, leave it unbounded.
@ -350,6 +353,8 @@ class GenerateCandidates {
const size_t M_; const size_t M_;
const size_t K_; const size_t K_;
const size_t N_;
const size_t sizeof_TC_;
const size_t max_mr_; const size_t max_mr_;
const size_t nr_; const size_t nr_;
@ -365,23 +370,25 @@ class GenerateCandidates {
} // namespace } // namespace
// Facade to avoid exposing `GenerateCandidates` in the header. // Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t nr, size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, const IndexRangePartition& ranges_np,
bool print_config) { bool print_config) {
return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)(); return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
print_config)();
} }
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package // memory accesses or false sharing, unless there are insufficient per-package
// rows for that. // rows for that.
static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) { static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
size_t np_multiple = Allocator::QuantumBytes() / sizeof(float); size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly, // `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple. // choose a smaller multiple.
if (N % (np_multiple * num_packages)) { if (N % (np_multiple * num_packages)) {
const size_t min_multiple = Allocator::LineBytes() / sizeof(float); const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
np_multiple = np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) { if (HWY_UNLIKELY(np_multiple == 0)) {
@ -398,10 +405,10 @@ static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
} }
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t nr) const { size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages()); const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
return StaticPartition(IndexRange(0, N), num_packages, return StaticPartition(IndexRange(0, N), num_packages,
NPMultiple(N, nr, num_packages)); NPMultiple(N, sizeof_TC, nr, num_packages));
} }
MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) { MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) {

View File

@ -60,7 +60,7 @@ class MMParallel {
// Initial static partitioning of B rows across packages. // Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N, IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t nr) const; size_t sizeof_TC, size_t nr) const;
// For `BindB` and `BindC`. // For `BindB` and `BindC`.
size_t Node(size_t pkg_idx) const { size_t Node(size_t pkg_idx) const {
@ -170,13 +170,13 @@ class MMParallel {
NestedPools& pools_; NestedPools& pools_;
}; };
template <typename T> // float for C, double for partial template <typename TC> // BF16/float for C, double for partial
void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) { void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
if (!Allocator::ShouldBind()) return; if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np = const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR); parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(T); const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
bool ok = true; bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx); const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@ -185,7 +185,7 @@ void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
// BindRowsToPackageNodes may not be page-aligned. // BindRowsToPackageNodes may not be page-aligned.
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum); const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
const size_t end = hwy::RoundDownTo(cols_c.end(), quantum); const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T), ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
node); node);
} }
} }
@ -355,11 +355,11 @@ static inline const char* StringFromParA(MMParA par_a) {
class MMConfig { class MMConfig {
public: public:
MMConfig() = default; // for std::vector MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopOverKC`. // `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops. // `MMOrder` is how to parallelize the outer loops.
// `MMOut` is how/whether to parallelize filling the C result. // `MMOut` is how/whether to parallelize filling the C result.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`. // `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc, MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out, size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
int inner_tasks) int inner_tasks)
: mr_(static_cast<uint32_t>(mr)), : mr_(static_cast<uint32_t>(mr)),
@ -381,7 +381,7 @@ class MMConfig {
if (kc != K && (kc % kc_multiple) != 0) { if (kc != K && (kc % kc_multiple) != 0) {
HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple); HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
} }
if (nc % nc_multiple != 0) { if (nc != N && (nc % nc_multiple) != 0) {
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple); HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
} }
HWY_DASSERT(StringFromOrder(order_) != nullptr); HWY_DASSERT(StringFromOrder(order_) != nullptr);
@ -428,8 +428,8 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32); // for faster indexing static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop) #pragma pack(pop)
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr, std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t nr, size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, const IndexRangePartition& ranges_np,
bool print_config); bool print_config);
@ -588,8 +588,9 @@ class MMKeys {
// Per-MatMul-shape state. // Per-MatMul-shape state.
struct MMPerKey { struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel) MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
: ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {} MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {}
// Only profile if enabled and the main autotuner finished (the par_a // Only profile if enabled and the main autotuner finished (the par_a
// autotuner is per-package and we want to avoid synchronization). // autotuner is per-package and we want to avoid synchronization).
@ -623,18 +624,16 @@ struct MatMulEnv {
std::vector<MMPerKey> per_key; std::vector<MMPerKey> per_key;
}; };
// Arguments to MatMul() that are independent of the A/B type. // Arguments to MatMul() that are independent of the A/B/C types.
// Reduces register pressure compared to individual values/references. // Reduces register pressure compared to individual values/references.
struct MMArgs { struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const RowPtrD& partial, const float* HWY_RESTRICT add, const RowPtrD& partial)
const RowPtrF& C)
: env(&env), : env(&env),
per_key(&per_key), per_key(&per_key),
scale(scale), scale(scale),
add(add), add(add),
partial(partial), partial(partial) {}
C(C) {}
MatMulEnv* env; MatMulEnv* env;
MMPerKey* per_key; MMPerKey* per_key;
@ -643,7 +642,6 @@ struct MMArgs {
const float* HWY_RESTRICT add; const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity. // Same size as C, threads write at false-sharing-free granularity.
RowPtrD partial; RowPtrD partial;
RowPtrF C;
}; };
// Wrapper over hwy::Zone that is only enabled when autotuning finished. // Wrapper over hwy::Zone that is only enabled when autotuning finished.
@ -683,22 +681,22 @@ struct MMZone {
// `ofs` required for compressed T. // `ofs` required for compressed T.
template <typename T> template <typename T>
struct ConstMat { struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0) ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
: ptr(ptr), extents(extents), ofs(ofs) { : ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
HWY_DASSERT(ptr != nullptr); HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols);
} }
// TODO: support stride for page alignment.
size_t Row(size_t r) const { size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) { if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) { if (r >= extents.rows) {
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows); HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
} }
} }
return ofs + extents.cols * r; return ofs + r * stride;
} }
const Extents2D& Extents() const { return extents; } const Extents2D& Extents() const { return extents; }
size_t Stride() const { return extents.cols; } size_t Stride() const { return stride; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0. // subrange of the original rows starting at row 0.
@ -709,6 +707,7 @@ struct ConstMat {
const T* HWY_RESTRICT ptr; const T* HWY_RESTRICT ptr;
Extents2D extents; Extents2D extents;
size_t stride;
// `scale` allows expanding the smaller range of `SfpStream` to the original // `scale` allows expanding the smaller range of `SfpStream` to the original
// values. MatFromWeights sets this from `MatPtr`. // values. MatFromWeights sets this from `MatPtr`.
@ -721,9 +720,9 @@ struct ConstMat {
// For deducing T. // For deducing T.
template <typename T> template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
size_t ofs = 0) { size_t ofs = 0) {
return ConstMat<T>(ptr, extents, ofs); return ConstMat<T>(ptr, extents, stride, ofs);
} }
// For A argument to MatMul (activations). // For A argument to MatMul (activations).
@ -732,22 +731,25 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) { const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize()); HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()), return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols())); Extents2D(batch_size, row_vectors.Cols()),
row_vectors.Stride());
} }
template <typename T> template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) { ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs); ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
mat.scale = m.scale(); mat.scale = m.scale();
return mat; return mat;
} }
template <typename TB> template <typename TB>
void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) { void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
MMParallel& parallel) {
if (!Allocator::ShouldBind()) return; if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np = const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR); parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TB); const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx); const IndexRange& rows_b = ranges_np.Range(pkg_idx);

View File

@ -67,7 +67,7 @@ using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
// Generates inputs: deterministic, within max SfpStream range. // Generates inputs: deterministic, within max SfpStream range.
template <typename MatT> template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D extents, MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
auto mat = auto mat =
@ -112,12 +112,12 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
} }
// Returns 1-norm, used for estimating tolerable numerical differences. // Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) { double MaxRowAbsSum(const RowVectorBatch<float>& a) {
double max_row_abs_sum = 0.0; double max_row_abs_sum = 0.0;
for (size_t r = 0; r < extents.rows; r++) { for (size_t r = 0; r < a.BatchSize(); r++) {
const float* row = a + r * extents.cols; const float* row = a.Batch(r);
double row_abs_sum = 0.0; double row_abs_sum = 0.0;
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]); row_abs_sum += hwy::ScalarAbs(row[c]);
} }
max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum); max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
@ -126,41 +126,52 @@ double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
} }
// Returns the maximum absolute value of `a`. // Returns the maximum absolute value of `a`.
float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) { float MaxAbs(const RowVectorBatch<float>& a) {
float max_abs = 0.0f; float max_abs = 0.0f;
for (size_t c = 0; c < extents.cols; c++) { for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < extents.rows; r++) { for (size_t r = 0; r < a.BatchSize(); r++) {
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c])); const float* row = a.Batch(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
} }
} }
return max_abs; return max_abs;
} }
// B is already transposed. // B is already transposed.
template <typename TA, typename TB> template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B, void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
const RowPtrF& C_slow, const RowPtrF& C) { const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t num_a = A.extents.Area(); const size_t cols = A.extents.cols;
const size_t num_b = B.extents.Area(); const size_t B_rows = B.extents.rows;
const size_t N = hn::Lanes(df);
// Round up for DecompressAndZeroPad. // Round up for DecompressAndZeroPad.
FloatPtr a = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_a, N)); RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
FloatPtr b_trans = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_b, N)); RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
HWY_ASSERT(a && b_trans); RowVectorBatch<float> c_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
RowVectorBatch<float> c_slow_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
HWY_ASSERT(A.ofs == 0 && B.ofs == 0); HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a); for (size_t m = 0; m < A.extents.rows; ++m) {
DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b); DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
a_batch.Batch(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Batch(m), B_rows);
}
for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
b_trans_batch.Batch(n), cols);
}
// MatMul rounds inputs to BF16, so error is proportional to the max input // MatMul rounds inputs to BF16, so error is proportional to the max input
// magnitude, but also to f32 accumulation of rows in A and B. // magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a.get(), A.Extents()) * const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
MaxRowAbsSum(b_trans.get(), B.Extents()); const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const float max_abs =
MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>()); const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>()); const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
double tolerance = 10 * norm * eps_f32; double tolerance = 12 * norm * eps_f32;
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
// tolerance there. // tolerance there.
if (IsF32<TA>() && IsF32<TB>()) { if (IsF32<TA>() && IsF32<TB>()) {
@ -169,30 +180,38 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
if (tolerance > 500.0) { if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs); HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
} }
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
for (size_t r = 0; r < A.extents.rows; r++) { for (size_t r = 0; r < A.extents.rows; r++) {
const float* expected_row = C_slow.Row(r); const float* expected_row = c_slow_batch.Batch(r);
const float* actual_row = C.Row(r); const float* actual_row = c_batch.Batch(r);
for (size_t c = 0; c < B.extents.rows; c++) { for (size_t c = 0; c < B.extents.rows; c++) {
const double expected_value = static_cast<double>(expected_row[c]); const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]); const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
if (!(expected_value - tolerance <= actual_value && if (!in_range) {
actual_value <= expected_value + tolerance)) { const double max = HWY_MAX(expected_value, actual_value);
HWY_ABORT( const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f " "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f\n", "tolerance %f rel %E max_rel %E\n",
r, c, expected_value, actual_value, norm, max_abs, tolerance); r, c, expected_value, actual_value, norm, max_abs,
tolerance, rel, max_rel);
}
} }
} }
} }
} }
// B is already transposed. // B is already transposed.
template <typename TA, typename TB> template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B, HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
const float* HWY_RESTRICT add_row, MatMulEnv& env, const float* HWY_RESTRICT add_row, MatMulEnv& env,
const RowPtrF& C) { const RowPtr<TC>& C) {
// TA can be any Packed except NuqStream because it uses pointer // TA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not // arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs. // support a v_ofs.
@ -200,7 +219,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
const float scale = A.scale * B.scale; const float scale = A.scale * B.scale;
const hn::ScalableTag<float> df; // lane type is ignored const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const TB> b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area()); const PackedSpan<const TB> b_span =
MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_rows_c(0, A.Extents().rows);
const IndexRange all_cols_c(0, C.Cols()); const IndexRange all_cols_c(0, C.Cols());
@ -219,12 +239,12 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
get_col_c, all_clusters, get_col_c, all_clusters,
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR { [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : rows_c) { for (size_t r : rows_c) {
float* HWY_RESTRICT C_row = C.Row(r); TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) { for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f; const float add = add_row ? add_row[c] : 0.0f;
C_row[c] = C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.extents.cols, add + scale * Dot(df, b_span, c * B.Stride(),
A.ptr + A.Row(r), A.extents.cols); A.ptr + A.Row(r), A.extents.cols));
} }
} }
}); });
@ -239,14 +259,15 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed); elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
} }
template <typename TA, typename TB = TA> template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env) { MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(); hwy::ThreadPool& pool = env.parallel.Pools().Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac, fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>()); rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
env.print_config = true; env.print_config = false; // Too verbose.
env.print_best = true; env.print_best = true;
const Extents2D A_extents(rows_ac, cols_a_rows_b); const Extents2D A_extents(rows_ac, cols_a_rows_b);
@ -255,8 +276,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool); MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool); MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents); RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents); RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
HWY_ASSERT(a && b_trans); HWY_ASSERT(a && b_trans);
std::unique_ptr<MatStorageT<float>> add_storage; std::unique_ptr<MatStorageT<float>> add_storage;
@ -269,14 +290,14 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const auto A = ConstMatFromWeights(*a); const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans); const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr; const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch); const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtrF C = RowPtrFromBatch(c_batch); const RowPtr<TC> C = RowPtrFromBatch(c_batch);
MatMulSlow(A, B, add_row, env, C_slow); MatMulSlow(A, B, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths. // A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) { for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C); MMPerKey* per_key = MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C); AssertClose(A, B, C_slow, C, line);
if (per_key->autotune.Best()) break; if (per_key->autotune.Best()) break;
} }
} }
@ -311,7 +332,7 @@ void TestTiny() {
for (size_t M = 1; M <= 12; ++M) { for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) { for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += max_packages * 4) { for (size_t N = 4; N <= 64; N += max_packages * 4) {
TestMatMul<F32, F32>(M, K, N, /*add=*/false, env); TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
} }
} }
} }
@ -332,56 +353,69 @@ void TestAllMatMul() {
Allocator::Init(pools.Topology(), /*enable_bind=*/true); Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools); MatMulEnv env(pools);
// Sizes seen in gemma_test 2B. // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
TestMatMul<F32>(1, 2048, 512, /*add=*/false, env); TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env); // TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env); // TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env); // TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env); // TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env, __LINE__);
TestMatMul<F32>(5, 2048, 512, /*add=*/false, env); // TestMatMul<F32>(5, 2048, 512, /*add=*/false, env, __LINE__);
TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env); // TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env, __LINE__);
TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env); // TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env, __LINE__);
TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env); // TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env, __LINE__);
// medium-sized square // medium-sized square, f32 vs bf16 for A, B, C; plus add.
TestMatMul<F32>(512, 512, 512, /*add=*/false, env); TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(512, 512, 512, /*add=*/true, env); TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env); TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env); TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env); TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env); TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
// minimal non-square test. kColsARowsB must be at least 2 vectors. // minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<F32>(35, 128, 32, /*add=*/false, env); TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env); TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env); TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env); TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env); TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env); TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32>(4, 128, 32, /*add=*/true, env); TestMatMul<F32>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16>(4, 128, 32, /*add=*/false, env); TestMatMul<BF16>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env); TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env); TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env); TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env); TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32>(3, 128, 32, /*add=*/false, env); TestMatMul<F32>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(3, 128, 32, /*add=*/true, env); TestMatMul<BF16>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env); TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env); TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env); TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env); TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32>(2, 128, 64, /*add=*/true, env); TestMatMul<F32>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16>(2, 128, 64, /*add=*/false, env); TestMatMul<BF16>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env); TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env); TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env); TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env); TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 128, 32, /*add=*/false, env); TestMatMul<F32>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(1, 128, 32, /*add=*/true, env); TestMatMul<BF16>(1, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env); TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env); TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env); TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env); TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)