Support BF16 output of MatMul

Adds Stride to ConstMat to support decompression of the C output for the test.
matmul_test: add caller line numbers to failure output.
Also skip the "nc is not a multiple of nc_multiple" warning when nc == N.
PiperOrigin-RevId: 731096662
Jan Wassenberg 2025-02-25 17:52:50 -08:00 committed by Copybara-Service
parent b3b4b9f92f
commit 2bdf26d81d
6 changed files with 353 additions and 262 deletions
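
A minimal usage sketch of what this change enables, assembled from the test/bench helpers visible in the diff below (GenerateMat, ConstMatFromWeights, AllocateAlignedRows, RowPtrFromBatch). The output element type TC is now deduced from the RowPtr passed as C; the shapes and the SFP weight type here are illustrative only.

// Illustrative sketch, not part of the diff: request a BF16 result matrix.
const Extents2D A_extents(M, K), B_extents(N, K), C_extents(M, N);  // B transposed
MatStoragePtr<BF16> a = GenerateMat<BF16>(A_extents, pool);
MatStoragePtr<SFP> b_trans = GenerateTransposedMat<SFP>(B_extents, pool);
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
RowVectorBatch<BF16> c_batch = AllocateAlignedRows<BF16>(C_extents);
const RowPtr<BF16> C = RowPtrFromBatch(c_batch);            // TC = BF16
MMPerKey* per_key = MatMul(A, B, /*add=*/nullptr, env, C);  // deduces TC = BF16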

View File

@ -734,7 +734,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
// Hidden layer -> output layer.
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}
@ -773,8 +774,9 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
multiplier.Row(0), ff_hidden_dim * num_interleaved);
// Hidden layer -> output layer.
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
auto activations_mat = MakeConstMat(hidden_activations.Row(0),
Extents2D(num_interleaved, ff_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
}
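
Both call sites above now forward the activations' row stride, because ConstMat (see matmul.h further down) addresses row r as ofs + r * stride rather than ofs + r * cols. A small sketch with made-up numbers, assuming the activations are padded so that Stride() exceeds Cols():

// Suppose hidden_activations is 3 x 100 floats stored with a row stride of 128.
// Addressing by cols would place row 1 at element 100, inside row 0's padding;
// addressing by stride correctly places it at element 128.
const auto mat = MakeConstMat(hidden_activations.Row(0),
                              Extents2D(/*rows=*/3, /*cols=*/100),
                              hidden_activations.Stride());  // 128
// mat.Row(1) == 1 * 128  (ofs == 0 for activations)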

View File

@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// Generates inputs and prints observed throughput of MatMul.
// M = A rows, K = A cols, N = C cols.
template <typename MatTA, typename MatTB = MatTA>
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
}
fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
add, TypeName<MatTA>(), TypeName<MatTB>());
fprintf(stderr,
"BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", //
M, K, N, add, TypeName<TA>(), TypeName<TB>(), TypeName<TC>());
const Extents2D A_extents(M, K);
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);
RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
std::unique_ptr<MatStorageT<float>> add_storage;
if (add) {
@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
add_storage->set_scale(1.0f);
}
MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
HWY_ASSERT(a && b_trans);
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C = RowPtrFromBatch(c_batch);
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
// Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(B_extents.rows, B, env.parallel);
BindB(B_extents.rows, sizeof(TC), B, env.parallel);
BindC(A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault;
@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
per_key = MatMul(A, B, add_row, env, C);
const double t1 = hwy::platform::Now();
double elapsed = t1 - t0;
keep += C.Row(0)[hwy::Unpredictable1()];
keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
// Only record times after autotuning finished.
if (per_key->autotune.Best()) times.push_back(elapsed);
@ -229,8 +230,8 @@ void BenchAllMatMul() {
for (size_t batch_size : {1, 4, 128, 512}) {
constexpr bool kAdd = false;
BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
}
PROFILER_PRINT_RESULTS();

View File

@ -46,7 +46,7 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register.
template <class DF, class DBF = hn::Repartition<hwy::bfloat16_t, DF>>
template <class DF, class DBF = hn::Repartition<BF16, DF>>
static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
// Promoting odd means clearing the lower 16 bits. Doing this via AND
// requires a second input vector, which we prefer to avoid due to high
@ -70,6 +70,16 @@ static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
#endif
}
// Converts from float intermediate to MatMul output type `TC`.
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_F32_D(DC)>
hn::Vec<DC> TCFromF32(DC /*dc*/, hn::Vec<DF> vf) {
return vf;
}
template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_BF16_D(DC)>
hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
return hn::DemoteTo(dc, vf);
}
// Tag classes, passed to `MMKernel::A2C0` to choose between writing one
// (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
// first kc result to partial, or accumulating the next kc result into partial
@ -93,14 +103,14 @@ class MMStoreHorizontalSumsIntoC {
// Thus we compute the horizontal sums of each `Crc`. The elements may be
// permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
// this does not change their horizontal sum.
template <class DF, class VF = hn::Vec<DF>>
template <class DF, class VF = hn::Vec<DF>, typename TC>
HWY_INLINE void operator()(DF df, //
VF C00, VF C01, VF C02, VF C03, //
VF C10, VF C11, VF C12, VF C13, //
VF C20, VF C21, VF C22, VF C23, //
VF C30, VF C31, VF C32, VF C33, //
const size_t row_c, const size_t col_c,
const MMArgs& args) const {
const MMArgs& args, const RowPtr<TC>& C) const {
float buf[16 * hn::MaxLanes(df)];
const size_t N = hn::Lanes(df);
// Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@ -136,10 +146,10 @@ class MMStoreHorizontalSumsIntoC {
if constexpr (kAdd) {
vadd = hn::Load(d4, args.add + col_c);
}
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, args.C, row_c, col_c);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, args.C, row_c, col_c);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, args.C, row_c, col_c);
MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, args.C, row_c, col_c);
MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
}
private:
@ -155,34 +165,36 @@ class MMStoreHorizontalSumsIntoC {
}
// Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N,
const float* HWY_RESTRICT buf) {
template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE VF4 MaybeLoad(DF4 df4, size_t N,
const float* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) {
return hn::Load(d4, buf + 4 * kRow * N);
return hn::Load(df4, buf + 4 * kRow * N);
} else {
return hn::Zero(d4);
return hn::Zero(df4);
}
}
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum,
const float* HWY_RESTRICT buf) {
template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE VF4 MaybeAdd(DF4 df4, size_t N, VF4 sum,
const float* HWY_RESTRICT buf) {
if constexpr (kRow < kRowsAC) {
return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N));
return hn::Add(sum, hn::Load(df4, buf + 4 * kRow * N));
} else {
return sum;
}
}
template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
static HWY_INLINE void MaybeScaleAndStore(D4 d4, V4 sum, V4 vscale, V4 vadd,
const RowPtrF& C,
template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
VF4 vadd, const RowPtr<TC>& C,
const size_t row_c,
const size_t col_c) {
if constexpr (kRow < kRowsAC) {
float* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
hn::Store(hn::MulAdd(sum, vscale, vadd), d4, pos);
TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
const hn::Rebind<TC, DF4> dc4;
const VF4 out = hn::MulAdd(sum, vscale, vadd);
hn::Store(TCFromF32(dc4, out), dc4, pos);
}
}
}; // MMStoreHorizontalSumsIntoC
@ -340,14 +352,14 @@ class MMKernel {
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
// Calls `LoopOverKC` for each of `mc` rows of A in steps of `mr`. `A_view`
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
template <class Tag>
template <class Tag, typename TC>
static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
size_t mr, const IndexRange& range_mc,
const size_t row_b, size_t kc, Tag tag,
const MMArgs& args) {
const MMArgs& args, const RowPtr<TC>& C) {
HWY_DASSERT(1 <= mr && mr <= kMaxMR);
const size_t row0 = range_mc.begin();
const size_t mc = range_mc.Num();
@ -356,7 +368,7 @@ class MMKernel {
// M == 1, or x86 with 8 SIMD registers:
if (HWY_UNLIKELY(mr == 1)) {
for (; imc < mc; ++imc) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
}
return;
}
@ -365,11 +377,11 @@ class MMKernel {
if (HWY_UNLIKELY(mr == 2)) {
if (HWY_LIKELY(mc >= 2)) {
for (; imc <= mc - 2; imc += 2) {
LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
}
}
if (HWY_UNLIKELY(imc != mc)) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
}
return;
}
@ -377,17 +389,17 @@ class MMKernel {
HWY_DASSERT(mr == 4);
if (HWY_LIKELY(mc >= 4)) {
for (; imc <= mc - 4; imc += 4) {
LoopOverKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
}
}
const size_t remainder_mc = mc - imc;
HWY_DASSERT(remainder_mc < 4);
if (HWY_UNLIKELY(remainder_mc & 2)) {
LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
imc += 2;
}
if (HWY_UNLIKELY(remainder_mc & 1)) {
LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
imc += 1;
}
HWY_DASSERT(imc == mc);
@ -484,11 +496,11 @@ class MMKernel {
// with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
// BF16 so we can load directly without `Decompress2`, which is expensive for
// NUQ and requires 2x unrolling, which requires more loads.
template <size_t kRowsAC, class Tag>
static HWY_INLINE void LoopOverKC(const RowPtrBF& A_view,
const RowPtrBF& B_view, size_t row_ac,
size_t imc, size_t col_c, size_t kc,
Tag tag, const MMArgs& args) {
template <size_t kRowsAC, class Tag, typename TC>
static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
size_t row_ac, size_t imc, size_t col_c,
size_t kc, Tag tag, const MMArgs& args,
const RowPtr<TC>& C) {
const hn::ScalableTag<BF16> dbf;
using VBF = hn::Vec<decltype(dbf)>;
const size_t NBF = hn::Lanes(dbf);
@ -602,11 +614,11 @@ class MMKernel {
if (args.add) {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args);
C31, C32, C33, row_ac, col_c, args, C);
} else {
MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
C31, C32, C33, row_ac, col_c, args);
C31, C32, C33, row_ac, col_c, args, C);
}
} else {
MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@ -627,40 +639,44 @@ class MMScaleDemoteAdd {
// TODO: fuse with subsequent operations - function pointer?
// Although this region in `outputs.C` is not touched again, streaming stores
// do not help on SKX and Zen4. TODO: re-check this.
template <typename TC>
static HWY_INLINE void FillC(const IndexRange& range_mc,
const IndexRange& range_nc, const MMArgs& args) {
const IndexRange& range_nc, const MMArgs& args,
const RowPtr<TC>& C) {
size_t row_c = range_mc.begin();
if (args.add) {
constexpr bool kAdd = true;
if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args);
Do4Rows<kAdd>(row_c, range_nc, args, C);
}
}
for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args);
Do1Row<kAdd>(row_c, range_nc, args, C);
}
} else {
constexpr bool kAdd = false;
if (range_mc.Num() >= 4) {
for (; row_c <= range_mc.end() - 4; row_c += 4) {
Do4Rows<kAdd>(row_c, range_nc, args);
Do4Rows<kAdd>(row_c, range_nc, args, C);
}
}
for (; row_c < range_mc.end(); ++row_c) {
Do1Row<kAdd>(row_c, range_nc, args);
Do1Row<kAdd>(row_c, range_nc, args, C);
}
}
}
private:
// Unrolled for 4 rows to reduce the number of loads from `add`.
template <bool kAdd>
template <bool kAdd, typename TC>
static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
const MMArgs& args) {
const MMArgs& args, const RowPtr<TC>& C) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
@ -669,10 +685,10 @@ class MMScaleDemoteAdd {
const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0);
float* HWY_RESTRICT cr1 = args.C.Row(row_c + 1);
float* HWY_RESTRICT cr2 = args.C.Row(row_c + 2);
float* HWY_RESTRICT cr3 = args.C.Row(row_c + 3);
TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
// We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin();
@ -713,15 +729,24 @@ class MMScaleDemoteAdd {
m30 = hn::Mul(d30, vscale);
m31 = hn::Mul(d31, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
const VF f10 = hn::DemoteTo(df, m10);
const VF f11 = hn::DemoteTo(df, m11);
const VF f20 = hn::DemoteTo(df, m20);
const VF f21 = hn::DemoteTo(df, m21);
const VF f30 = hn::DemoteTo(df, m30);
const VF f31 = hn::DemoteTo(df, m31);
// Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c);
hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND);
hn::Store(hn::DemoteTo(df, m10), df, cr1 + col_c);
hn::Store(hn::DemoteTo(df, m11), df, cr1 + col_c + ND);
hn::Store(hn::DemoteTo(df, m20), df, cr2 + col_c);
hn::Store(hn::DemoteTo(df, m21), df, cr2 + col_c + ND);
hn::Store(hn::DemoteTo(df, m30), df, cr3 + col_c);
hn::Store(hn::DemoteTo(df, m31), df, cr3 + col_c + ND);
hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c);
hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND);
hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c);
hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND);
hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c);
hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND);
}
}
@ -750,25 +775,31 @@ class MMScaleDemoteAdd {
m20 = hn::Mul(d20, vscale);
m30 = hn::Mul(d30, vscale);
}
hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining);
hn::StoreN(hn::DemoteTo(df, m10), df, cr1 + col_c, remaining);
hn::StoreN(hn::DemoteTo(df, m20), df, cr2 + col_c, remaining);
hn::StoreN(hn::DemoteTo(df, m30), df, cr3 + col_c, remaining);
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f10 = hn::DemoteTo(df, m10);
const VF f20 = hn::DemoteTo(df, m20);
const VF f30 = hn::DemoteTo(df, m30);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining);
hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining);
}
}
// Same as above but handles a single row (for remainder rows).
template <bool kAdd>
template <bool kAdd, typename TC>
static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
const MMArgs& args) {
const MMArgs& args, const RowPtr<TC>& C) {
const hn::ScalableTag<double> dd;
const hn::Rebind<float, decltype(dd)> df; // result of DemoteTo
const hn::Rebind<TC, decltype(dd)> dc;
using VD = hn::Vec<decltype(dd)>;
using VF = hn::Vec<decltype(df)>;
const size_t ND = hn::Lanes(dd);
const VD vscale = hn::Set(dd, args.scale);
const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0);
TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
// We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin();
@ -791,9 +822,12 @@ class MMScaleDemoteAdd {
m00 = hn::Mul(d00, vscale);
m01 = hn::Mul(d01, vscale);
}
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
const VF f01 = hn::DemoteTo(df, m01);
// Note that Stream is neutral on SKX and harmful on Zen4.
hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c);
hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND);
hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
}
}
@ -813,7 +847,9 @@ class MMScaleDemoteAdd {
} else {
m00 = hn::Mul(d00, vscale);
}
hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining);
// First convert f64 to f32.
const VF f00 = hn::DemoteTo(df, m00);
hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
}
}
}; // MMScaleDemoteAdd
@ -849,20 +885,21 @@ class MMPerPackage {
// B is decompressed several call layers lower, but not all member functions
// depend on TB, so pass it as an argument instead of templating the class.
template <typename TB>
HWY_NOINLINE void operator()(const ConstMat<TB>& B) const {
template <typename TB, typename TC>
HWY_NOINLINE void operator()(const ConstMat<TB>& B,
const RowPtr<TC>& C) const {
// TODO: include NUQ tables? NumPacked in ConstMat?
const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
switch (order_) {
case MMOrder::kNT:
return DoNT(B, num_packed_B);
return DoNT(B, num_packed_B, C);
case MMOrder::kNT_K:
return DoNT_K(B, num_packed_B);
return DoNT_K(B, num_packed_B, C);
case MMOrder::kNT_MT:
return DoNT_MT(B, num_packed_B);
return DoNT_MT(B, num_packed_B, C);
case MMOrder::kNT_MT_K:
return DoNT_MT_K(B, num_packed_B);
return DoNT_MT_K(B, num_packed_B, C);
default:
HWY_UNREACHABLE;
}
@ -878,13 +915,14 @@ class MMPerPackage {
// Granularity of `ForNP`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing.
static size_t MultipleNP() {
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof(float));
static size_t MultipleNP(size_t sizeof_TC) {
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
}
// Single M and K, parallel N. Fills all of C directly.
template <typename TB>
HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B) const {
template <typename TB, typename TC>
HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone;
zone.MaybeEnter("MM.NT", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -897,7 +935,7 @@ class MMPerPackage {
// Similar to `loop_nc` below, but here we hoisted `A_view`.
args_.env->parallel.ForNP(
range_np_, MultipleNP(), inner_tasks_, pkg_idx_,
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride);
@ -910,7 +948,7 @@ class MMPerPackage {
DecompressB(B, num_packed_B, row_b, range_K, B_view);
}
MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
args_);
args_, C);
}
});
@ -918,8 +956,9 @@ class MMPerPackage {
}
// Single M, parallel N, sequential K. Fills all of partial.
template <typename TB>
HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B) const {
template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone;
zone.MaybeEnter("MM.NT_K", args_);
HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@ -942,13 +981,13 @@ class MMPerPackage {
zone.MaybeEnter("MM.NT_K.DecB", args_);
DecompressB(B, num_packed_B, row_b, range_kc, B_view);
}
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
args_);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C);
}
};
args_.env->parallel.ForNP(
range_np_, MultipleNP(), inner_tasks_, pkg_idx_,
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
@ -965,13 +1004,13 @@ class MMPerPackage {
MMZone fill_zone;
if (out_ == MMOut::kCopy) {
fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_);
MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
} else if (out_ == MMOut::kParM) {
fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
args_.env->parallel.ForRangeMC(
range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
args_);
args_, C);
});
} else {
HWY_UNREACHABLE; // kDirect is only used with kNT.
@ -980,8 +1019,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, single K.
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TB>
HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B) const {
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone;
zone.MaybeEnter("MM.NT_MT", args_);
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@ -1006,7 +1046,7 @@ class MMPerPackage {
DecompressB(B, num_packed_B, row_b, range_K, B_view);
}
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
args_);
args_, C);
}
});
@ -1015,8 +1055,9 @@ class MMPerPackage {
// Parallel loops over mc/nc blocks of M/range_np, sequential K.
// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
template <typename TB>
HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B) const {
template <typename TB, typename TC>
HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
const RowPtr<TC>& C) const {
MMZone zone;
zone.MaybeEnter("MM.NT_MT_K", args_);
const size_t kc_max = ranges_kc_.TaskSize();
@ -1039,8 +1080,8 @@ class MMPerPackage {
zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
DecompressB(B, num_packed_B, row_b, range_kc, B_view);
}
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
args_);
MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
C);
}
}; // loop_nc
args_.env->parallel.ForRangesMC_NC(
@ -1063,7 +1104,7 @@ class MMPerPackage {
HWY_DASSERT(out_ == MMOut::kCopy);
MMZone fill_zone;
fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_);
MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
});
}
@ -1083,6 +1124,8 @@ class MMPerPackage {
const IndexRange& range_K) HWY_ATTR {
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
// otherwise, padding overwrites neighbors
HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols);
for (size_t row_a : range_M) {
const PackedSpan<const TA> from =
MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
@ -1224,9 +1267,10 @@ struct MMImpl {
// Called from `MatMul` from two places: either with the next autotune config,
// or with the best config.
template <typename TA, typename TB>
template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A,
const ConstMat<TB>& B, const MMArgs& args,
const ConstMat<TB>& B, const RowPtr<TC>& C,
const MMArgs& args,
const MMConfig& config) {
MMZone matmul_zone;
matmul_zone.MaybeEnter("MM.DoMatMul", args);
@ -1235,7 +1279,7 @@ struct MMImpl {
args.env->parallel.ForPkg(
args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
MMPerPackage(A, args, config, pkg_idx, range_np)(B);
MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
});
}
};
@ -1260,10 +1304,10 @@ struct MMImpl {
// invalidated by the next call to `MatMul`.
//
// Uses considerable stack space: at least 40 KiB per thread.
template <typename TA, typename TB>
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtrF& C) {
const RowPtr<TC>& C) {
const size_t M = A.Extents().rows;
const size_t K = A.Extents().cols;
const size_t N = B.Extents().rows;
@ -1281,15 +1325,16 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
// invalidates `MMAutoTune::Best()`
index = env.per_key.size();
env.per_key.push_back(MMPerKey(max_packages, N, kNR, env.parallel));
env.per_key.push_back(
MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
}
MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add,
env.storage.Partial(), C);
env.storage.Partial());
if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, args, *tuner.Best());
MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
return &per_key;
}
@ -1306,13 +1351,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
HWY_ASSERT(N % kNR == 0);
// Negligible CPU time.
tuner.SetCandidates(MMCandidates(M, K, N, MMKernel::kMaxMR, kNR,
tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
per_key.ranges_np, env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();
const uint64_t t0 = hwy::timer::Start();
MMImpl::DoMatMul(A, B, args, cfg);
MMImpl::DoMatMul(A, B, C, args, cfg);
const uint64_t t1 =
env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /

View File

@ -60,10 +60,13 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables.
class GenerateCandidates {
public:
GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr,
GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
: M_(M),
K_(K),
N_(N),
sizeof_TC_(sizeof_TC),
max_mr_(max_mr),
nr_(nr),
// These influence kc/nc, but are also stored in `MMConfig` for
@ -71,7 +74,7 @@ class GenerateCandidates {
// is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size.
kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
nc_multiple_(Allocator::StepBytes() / sizeof(float)),
nc_multiple_(Allocator::StepBytes() / sizeof_TC),
ranges_np_(ranges_np),
print_config_(print_config) {}
@ -88,7 +91,7 @@ class GenerateCandidates {
for (size_t nc : NC(mr, mc, kc, order)) {
for (int inner_tasks : all_inner_tasks) {
for (MMOut out : all_outs) {
const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_,
const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
nc_multiple_, order, out, inner_tasks);
const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
@ -114,7 +117,7 @@ class GenerateCandidates {
private:
using SizeVec = std::vector<size_t>;
// How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may
// How many rows of A per call to `MMKernel::LoopKC`. Lower values may
// be better for SIMD targets with fewer registers.
SizeVec MR() const {
const int64_t target = hwy::DispatchedTarget();
@ -153,7 +156,7 @@ class GenerateCandidates {
// The number of A and B columns to read between updating `partial`.
SizeVec KC(size_t mr, MMOrder order) const {
// `LoopOverKC` handles up to `mr` rows of A.
// `LoopKC` handles up to `mr` rows of A.
const size_t rows_a = HWY_MIN(M_, mr);
// After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
@ -161,9 +164,9 @@ class GenerateCandidates {
// is important that B fits in L1, because batch=1 only has a single row of
// A and thus no reuse of the packed B. When L1-resident, we can use the
// separate `DecompressAndZeroPad` to write `kc` columns, rather than having
// to integrate `Decompress2` into `LoopOverKC`, which is less efficient for
// to integrate `Decompress2` into `LoopKC`, which is less efficient for
// TB=NUQ due to less amortization of the table loads. Due to the low L1
// latency, the packing is still effectively fused into `LoopOverKC`. It may
// latency, the packing is still effectively fused into `LoopKC`. It may
// be better to round up and accept a few L2 accesses in exchange for
// fewer loops over K, and thus fewer writes to `partial`. Hence we do not
// subtract the output and buf, and allow using more than the actual L1
@ -255,7 +258,7 @@ class GenerateCandidates {
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max;
const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double);
const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double);
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded.
@ -350,6 +353,8 @@ class GenerateCandidates {
const size_t M_;
const size_t K_;
const size_t N_;
const size_t sizeof_TC_;
const size_t max_mr_;
const size_t nr_;
@ -365,23 +370,25 @@ class GenerateCandidates {
} // namespace
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
size_t nr,
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)();
return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof(float);
static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = Allocator::LineBytes() / sizeof(float);
const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
@ -398,10 +405,10 @@ static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
}
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t nr) const {
size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
return StaticPartition(IndexRange(0, N), num_packages,
NPMultiple(N, nr, num_packages));
NPMultiple(N, sizeof_TC, nr, num_packages));
}
MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) {
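
NPMultiple and the nc granularity now scale with the C element size. A worked example, assuming a 4 KiB binding quantum and 64-byte cache lines (the real values come from Allocator):

//   TC = float (4 bytes): np_multiple = 4096 / 4 = 1024 C columns,
//                         min_multiple =   64 / 4 =   16 C columns
//   TC = BF16  (2 bytes): np_multiple = 4096 / 2 = 2048 C columns,
//                         min_multiple =   64 / 2 =   32 C columns
// A BF16 element is half as wide, so a page or cache line of C spans twice as
// many columns; partitioning B rows (= C columns) on these wider multiples
// keeps each package's slice of C page-aligned and avoids false sharing.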

View File

@ -60,7 +60,7 @@ class MMParallel {
// Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t nr) const;
size_t sizeof_TC, size_t nr) const;
// For `BindB` and `BindC`.
size_t Node(size_t pkg_idx) const {
@ -170,13 +170,13 @@ class MMParallel {
NestedPools& pools_;
};
template <typename T> // float for C, double for partial
void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
template <typename TC> // BF16/float for C, double for partial
void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(T);
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@ -185,7 +185,7 @@ void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
// BindRowsToPackageNodes may not be page-aligned.
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T),
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
node);
}
}
@ -355,11 +355,11 @@ static inline const char* StringFromParA(MMParA par_a) {
class MMConfig {
public:
MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopOverKC`.
// `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops.
// `MMOut` is how/whether to parallelize filling the C result.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc,
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
int inner_tasks)
: mr_(static_cast<uint32_t>(mr)),
@ -381,7 +381,7 @@ class MMConfig {
if (kc != K && (kc % kc_multiple) != 0) {
HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
}
if (nc % nc_multiple != 0) {
if (nc != N && (nc % nc_multiple) != 0) {
HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
}
HWY_DASSERT(StringFromOrder(order_) != nullptr);
@ -428,8 +428,8 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
size_t nr,
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config);
@ -588,8 +588,9 @@ class MMKeys {
// Per-MatMul-shape state.
struct MMPerKey {
MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {}
MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
MMParallel& parallel)
: ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {}
// Only profile if enabled and the main autotuner finished (the par_a
// autotuner is per-package and we want to avoid synchronization).
@ -623,18 +624,16 @@ struct MatMulEnv {
std::vector<MMPerKey> per_key;
};
// Arguments to MatMul() that are independent of the A/B type.
// Arguments to MatMul() that are independent of the A/B/C types.
// Reduces register pressure compared to individual values/references.
struct MMArgs {
MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
const float* HWY_RESTRICT add, const RowPtrD& partial,
const RowPtrF& C)
const float* HWY_RESTRICT add, const RowPtrD& partial)
: env(&env),
per_key(&per_key),
scale(scale),
add(add),
partial(partial),
C(C) {}
partial(partial) {}
MatMulEnv* env;
MMPerKey* per_key;
@ -643,7 +642,6 @@ struct MMArgs {
const float* HWY_RESTRICT add;
// Same size as C, threads write at false-sharing-free granularity.
RowPtrD partial;
RowPtrF C;
};
// Wrapper over hwy::Zone that is only enabled when autotuning finished.
@ -683,22 +681,22 @@ struct MMZone {
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
: ptr(ptr), extents(extents), ofs(ofs) {
ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
: ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols);
}
// TODO: support stride for page alignment.
size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) {
HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
}
}
return ofs + extents.cols * r;
return ofs + r * stride;
}
const Extents2D& Extents() const { return extents; }
size_t Stride() const { return extents.cols; }
size_t Stride() const { return stride; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
@ -709,6 +707,7 @@ struct ConstMat {
const T* HWY_RESTRICT ptr;
Extents2D extents;
size_t stride;
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. MatFromWeights sets this from `MatPtr`.
@ -721,9 +720,9 @@ struct ConstMat {
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, ofs);
return ConstMat<T>(ptr, extents, stride, ofs);
}
// For A argument to MatMul (activations).
@ -732,22 +731,25 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()));
Extents2D(batch_size, row_vectors.Cols()),
row_vectors.Stride());
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
mat.scale = m.scale();
return mat;
}
template <typename TB>
void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) {
void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR);
parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
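
The MMConfig constructor above also stops warning when nc equals N, which is the case the commit message calls out: a single nc range then covers every column of C, so the nc_multiple constraint has no partition boundary to protect. A small worked example with an illustrative nc_multiple of 32:

//   N = 400, nc = 192: ok, 192 = 6 * 32
//   N = 400, nc = 200: warns, 200 is not a multiple of 32
//   N = 200, nc = 200: nc == N, no warning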

View File

@ -67,7 +67,7 @@ using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat =
@ -112,12 +112,12 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
}
// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
double max_row_abs_sum = 0.0;
for (size_t r = 0; r < extents.rows; r++) {
const float* row = a + r * extents.cols;
for (size_t r = 0; r < a.BatchSize(); r++) {
const float* row = a.Batch(r);
double row_abs_sum = 0.0;
for (size_t c = 0; c < extents.cols; c++) {
for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]);
}
max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
@ -126,41 +126,52 @@ double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
}
// Returns the maximum absolute value of `a`.
float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) {
float MaxAbs(const RowVectorBatch<float>& a) {
float max_abs = 0.0f;
for (size_t c = 0; c < extents.cols; c++) {
for (size_t r = 0; r < extents.rows; r++) {
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c]));
for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < a.BatchSize(); r++) {
const float* row = a.Batch(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
}
}
return max_abs;
}
// B is already transposed.
template <typename TA, typename TB>
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
const RowPtrF& C_slow, const RowPtrF& C) {
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const hn::ScalableTag<float> df;
const size_t num_a = A.extents.Area();
const size_t num_b = B.extents.Area();
const size_t N = hn::Lanes(df);
const size_t cols = A.extents.cols;
const size_t B_rows = B.extents.rows;
// Round up for DecompressAndZeroPad.
FloatPtr a = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_a, N));
FloatPtr b_trans = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_b, N));
HWY_ASSERT(a && b_trans);
RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
RowVectorBatch<float> c_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
RowVectorBatch<float> c_slow_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
for (size_t m = 0; m < A.extents.rows; ++m) {
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
a_batch.Batch(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Batch(m), B_rows);
}
for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
b_trans_batch.Batch(n), cols);
}
// MatMul rounds inputs to BF16, so error is proportional to the max input
// magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a.get(), A.Extents()) *
MaxRowAbsSum(b_trans.get(), B.Extents());
const float max_abs =
MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
double tolerance = 10 * norm * eps_f32;
double tolerance = 12 * norm * eps_f32;
// Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
// tolerance there.
if (IsF32<TA>() && IsF32<TB>()) {
@ -169,30 +180,38 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
for (size_t r = 0; r < A.extents.rows; r++) {
const float* expected_row = C_slow.Row(r);
const float* actual_row = C.Row(r);
const float* expected_row = c_slow_batch.Batch(r);
const float* actual_row = c_batch.Batch(r);
for (size_t c = 0; c < B.extents.rows; c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
HWY_ABORT(
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f\n",
r, c, expected_value, actual_value, norm, max_abs, tolerance);
if (!in_range) {
const double max = HWY_MAX(expected_value, actual_value);
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E\n",
r, c, expected_value, actual_value, norm, max_abs,
tolerance, rel, max_rel);
}
}
}
}
}
// B is already transposed.
template <typename TA, typename TB>
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
const float* HWY_RESTRICT add_row, MatMulEnv& env,
const RowPtrF& C) {
const RowPtr<TC>& C) {
// TA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs.
@ -200,7 +219,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
const float scale = A.scale * B.scale;
const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const TB> b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area());
const PackedSpan<const TB> b_span =
MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
const IndexRange all_rows_c(0, A.Extents().rows);
const IndexRange all_cols_c(0, C.Cols());
@ -219,12 +239,12 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
get_col_c, all_clusters,
[&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
for (size_t r : rows_c) {
float* HWY_RESTRICT C_row = C.Row(r);
TC* HWY_RESTRICT C_row = C.Row(r);
for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f;
C_row[c] =
add + scale * Dot(df, b_span, c * B.extents.cols,
A.ptr + A.Row(r), A.extents.cols);
C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.Stride(),
A.ptr + A.Row(r), A.extents.cols));
}
}
});
@ -239,14 +259,15 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
}
template <typename TA, typename TB = TA>
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env) {
MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
env.print_config = true;
env.print_config = false; // Too verbose.
env.print_best = true;
const Extents2D A_extents(rows_ac, cols_a_rows_b);
@ -255,8 +276,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
HWY_ASSERT(a && b_trans);
std::unique_ptr<MatStorageT<float>> add_storage;
@ -269,14 +290,14 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtrF C = RowPtrFromBatch(c_batch);
const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
MatMulSlow(A, B, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C);
AssertClose(A, B, C_slow, C, line);
if (per_key->autotune.Best()) break;
}
}
@ -311,7 +332,7 @@ void TestTiny() {
for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += max_packages * 4) {
TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
}
}
}
@ -332,56 +353,69 @@ void TestAllMatMul() {
Allocator::Init(pools.Topology(), /*enable_bind=*/true);
MatMulEnv env(pools);
// Sizes seen in gemma_test 2B.
TestMatMul<F32>(1, 2048, 512, /*add=*/false, env);
TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env);
TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env);
TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env);
TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env);
TestMatMul<F32>(5, 2048, 512, /*add=*/false, env);
TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env);
TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env);
TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env);
// Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(5, 2048, 512, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env, __LINE__);
// TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env, __LINE__);
// medium-sized square
TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
// medium-sized square, f32 vs bf16 for A, B, C; plus add.
TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
// minimal non-square test. kColsARowsB must be at least 2 vectors.
TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env);
TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env);
TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env);
TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env);
TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env);
TestMatMul<F32>(4, 128, 32, /*add=*/true, env);
TestMatMul<BF16>(4, 128, 32, /*add=*/false, env);
TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env);
TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env);
TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env);
TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env);
TestMatMul<F32>(3, 128, 32, /*add=*/false, env);
TestMatMul<BF16>(3, 128, 32, /*add=*/true, env);
TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env);
TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env);
TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env);
TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env);
TestMatMul<F32>(2, 128, 64, /*add=*/true, env);
TestMatMul<BF16>(2, 128, 64, /*add=*/false, env);
TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env);
TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env);
TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env);
TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env);
TestMatMul<F32>(1, 128, 32, /*add=*/false, env);
TestMatMul<BF16>(1, 128, 32, /*add=*/true, env);
TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env);
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env);
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env);
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env);
TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<F32>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env, __LINE__);
TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env, __LINE__);
TestMatMul<F32>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16>(1, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
}
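
AssertClose above adds a relative-error fallback on top of the absolute tolerance, so that results differing by roughly one ulp of the output type (relevant once C is BF16) are still accepted. A worked example, assuming hwy::Epsilon<BF16>() is 2^-7:

//   max_rel = 1 + 2^-7 = 1.0078125
//   expected = 100.0, actual = 100.5 -> rel = 100.5 / 100.0 = 1.005 -> accepted
//   expected = 100.0, actual = 101.0 -> rel = 1.010 > 1.0078125     -> abort
// The relative check only runs after the absolute tolerance check has already
// failed, so tight absolute agreement is accepted as before.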
// NOLINTNEXTLINE(google-readability-namespace-comments)