mirror of https://github.com/google/gemma.cpp.git
Support BF16 output of MatMul.

Adds Stride to ConstMat, to support decompression of C output for tests.
matmul_test: add line numbers to output.
Also ignore "N is not a multiple of nc" when N == nc.
PiperOrigin-RevId: 731096662
This commit is contained in:
parent b3b4b9f92f
commit 2bdf26d81d
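In short: `MatMul` is now templated on the output type `TC` taken from its `RowPtr<TC>` C argument, and the kernels demote the float/double intermediates to `TC` via the new `TCFromF32` helpers. A minimal sketch of the new calling pattern, mirroring the `BenchMatMul<BF16, SFP, BF16>` benchmark in this diff (M, K, N, `pool` and `env` are assumed to be set up as in the benchmark; this snippet is illustrative, not part of the commit):

  // Request BF16 (instead of float) C output: allocate a BF16 row batch and pass its RowPtr.
  const Extents2D A_extents(M, K);
  const Extents2D B_extents(N, K);  // B is stored transposed
  const Extents2D C_extents(M, N);
  MatStoragePtr<BF16> a = GenerateMat<BF16>(A_extents, pool);
  MatStoragePtr<SFP> b_trans = GenerateTransposedMat<SFP>(B_extents, pool);
  const auto A = ConstMatFromWeights(*a);
  const auto B = ConstMatFromWeights(*b_trans);
  RowVectorBatch<BF16> c_batch = AllocateAlignedRows<BF16>(C_extents);
  const RowPtr<BF16> C = RowPtrFromBatch(c_batch);  // TC = BF16 is deduced from C
  MMPerKey* per_key = MatMul(A, B, /*add=*/nullptr, env, C);  // C rows are written as bf16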
@@ -734,7 +734,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
 
   // Hidden layer -> output layer.
   auto activations_mat = MakeConstMat(
-      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
+      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
+      hidden_activations.Stride());
 
   MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
@@ -773,8 +774,9 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
                 multiplier.Row(0), ff_hidden_dim * num_interleaved);
 
   // Hidden layer -> output layer.
-  auto activations_mat = MakeConstMat(
-      hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
+  auto activations_mat = MakeConstMat(hidden_activations.Row(0),
+                                      Extents2D(num_interleaved, ff_hidden_dim),
+                                      hidden_activations.Stride());
 
   MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
@@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
 
 // Generates inputs and prints observed throughput of MatMul.
 // M = A rows, K = A cols, N = C cols.
-template <typename MatTA, typename MatTB = MatTA>
+template <typename TA, typename TB = TA, typename TC = float>
 void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
   if (env.print_config || env.print_measurement) {
     fprintf(stderr, "\n");
   }
-  fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
-          add, TypeName<MatTA>(), TypeName<MatTB>());
+  fprintf(stderr,
+          "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",  //
+          M, K, N, add, TypeName<TA>(), TypeName<TB>(), TypeName<TC>());
 
   const Extents2D A_extents(M, K);
   const Extents2D B_extents(N, K);  // already transposed
   const Extents2D C_extents(M, N);
 
-  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
-  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
+  RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
+  RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
   if (add) {
@@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
     add_storage->set_scale(1.0f);
   }
 
-  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
-  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+  MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
+  MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
   HWY_ASSERT(a && b_trans);
   const auto A = ConstMatFromWeights(*a);
   const auto B = ConstMatFromWeights(*b_trans);
 
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
-  const RowPtrF C = RowPtrFromBatch(c_batch);
+  const RowPtr<TC> C = RowPtrFromBatch(c_batch);
 
   // Fewer reps for large batch sizes, which take longer.
   const size_t num_samples = M < 32 ? 20 : 12;
@@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(B_extents.rows, B, env.parallel);
+  BindB(B_extents.rows, sizeof(TC), B, env.parallel);
   BindC(A_extents.rows, C, env.parallel);
 
   Tristate use_spinning = Tristate::kDefault;
@@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
     per_key = MatMul(A, B, add_row, env, C);
     const double t1 = hwy::platform::Now();
     double elapsed = t1 - t0;
-    keep += C.Row(0)[hwy::Unpredictable1()];
+    keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
 
     // Only record times after autotuning finished.
     if (per_key->autotune.Best()) times.push_back(elapsed);
@@ -229,8 +230,8 @@ void BenchAllMatMul() {
 
   for (size_t batch_size : {1, 4, 128, 512}) {
     constexpr bool kAdd = false;
-    BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
-    BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
+    BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
+    BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
   }
 
   PROFILER_PRINT_RESULTS();
ops/matmul-inl.h (253 changed lines)
@@ -46,7 +46,7 @@ namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
 // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register.
-template <class DF, class DBF = hn::Repartition<hwy::bfloat16_t, DF>>
+template <class DF, class DBF = hn::Repartition<BF16, DF>>
 static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
   // Promoting odd means clearing the lower 16 bits. Doing this via AND
   // requires a second input vector, which we prefer to avoid due to high
@@ -70,6 +70,16 @@ static hn::VFromD<DF> FastPromoteOddTo(DF df, hn::VFromD<DBF> vbf) {
 #endif
 }
 
+// Converts from float intermediate to MatMul output type `TC`.
+template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_F32_D(DC)>
+hn::Vec<DC> TCFromF32(DC /*dc*/, hn::Vec<DF> vf) {
+  return vf;
+}
+template <class DC, class DF = hn::Rebind<float, DC>, HWY_IF_BF16_D(DC)>
+hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
+  return hn::DemoteTo(dc, vf);
+}
+
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
@@ -93,14 +103,14 @@ class MMStoreHorizontalSumsIntoC {
   // Thus we compute the horizontal sums of each `Crc`. The elements may be
   // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
   // this does not change their horizontal sum.
-  template <class DF, class VF = hn::Vec<DF>>
+  template <class DF, class VF = hn::Vec<DF>, typename TC>
   HWY_INLINE void operator()(DF df,  //
                              VF C00, VF C01, VF C02, VF C03,  //
                              VF C10, VF C11, VF C12, VF C13,  //
                              VF C20, VF C21, VF C22, VF C23,  //
                              VF C30, VF C31, VF C32, VF C33,  //
                              const size_t row_c, const size_t col_c,
-                             const MMArgs& args) const {
+                             const MMArgs& args, const RowPtr<TC>& C) const {
     float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -136,10 +146,10 @@ class MMStoreHorizontalSumsIntoC {
     if constexpr (kAdd) {
       vadd = hn::Load(d4, args.add + col_c);
     }
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, args.C, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, args.C, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, args.C, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, args.C, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
   }
 
  private:
@@ -155,34 +165,36 @@ class MMStoreHorizontalSumsIntoC {
   }
 
   // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
-  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
-  static HWY_INLINE V4 MaybeLoad(D4 d4, size_t N,
+  template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
+  static HWY_INLINE VF4 MaybeLoad(DF4 df4, size_t N,
                                   const float* HWY_RESTRICT buf) {
     if constexpr (kRow < kRowsAC) {
-      return hn::Load(d4, buf + 4 * kRow * N);
+      return hn::Load(df4, buf + 4 * kRow * N);
     } else {
-      return hn::Zero(d4);
+      return hn::Zero(df4);
     }
   }
 
-  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
-  static HWY_INLINE V4 MaybeAdd(D4 d4, size_t N, V4 sum,
+  template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>>
+  static HWY_INLINE VF4 MaybeAdd(DF4 df4, size_t N, VF4 sum,
                                  const float* HWY_RESTRICT buf) {
     if constexpr (kRow < kRowsAC) {
-      return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N));
+      return hn::Add(sum, hn::Load(df4, buf + 4 * kRow * N));
    } else {
       return sum;
     }
   }
 
-  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
-  static HWY_INLINE void MaybeScaleAndStore(D4 d4, V4 sum, V4 vscale, V4 vadd,
-                                            const RowPtrF& C,
+  template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
+  static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
+                                            VF4 vadd, const RowPtr<TC>& C,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
-      float* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
-      hn::Store(hn::MulAdd(sum, vscale, vadd), d4, pos);
+      TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
+      const hn::Rebind<TC, DF4> dc4;
+      const VF4 out = hn::MulAdd(sum, vscale, vadd);
+      hn::Store(TCFromF32(dc4, out), dc4, pos);
     }
   }
 };  // MMStoreHorizontalSumsIntoC
@@ -340,14 +352,14 @@ class MMKernel {
   // or less on ISAs with fewer registers, or for the last few rows of A.
   static constexpr size_t kMaxMR = 4;
 
-  // Calls `LoopOverKC` for each of `mc` rows of A in steps of `mr`. `A_view`
+  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-  template <class Tag>
+  template <class Tag, typename TC>
   static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
                               size_t mr, const IndexRange& range_mc,
                               const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args) {
+                              const MMArgs& args, const RowPtr<TC>& C) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -356,7 +368,7 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
       }
       return;
     }
@@ -365,11 +377,11 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
       }
       return;
     }
@@ -377,17 +389,17 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopOverKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopOverKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopOverKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -484,11 +496,11 @@ class MMKernel {
   // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
-  template <size_t kRowsAC, class Tag>
-  static HWY_INLINE void LoopOverKC(const RowPtrBF& A_view,
-                                    const RowPtrBF& B_view, size_t row_ac,
-                                    size_t imc, size_t col_c, size_t kc,
-                                    Tag tag, const MMArgs& args) {
+  template <size_t kRowsAC, class Tag, typename TC>
+  static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
+                                size_t row_ac, size_t imc, size_t col_c,
+                                size_t kc, Tag tag, const MMArgs& args,
+                                const RowPtr<TC>& C) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -602,11 +614,11 @@ class MMKernel {
     if (args.add) {
       MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
           df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args);
+          C31, C32, C33, row_ac, col_c, args, C);
     } else {
       MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
           df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args);
+          C31, C32, C33, row_ac, col_c, args, C);
     }
   } else {
     MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@@ -627,40 +639,44 @@ class MMScaleDemoteAdd {
   // TODO: fuse with subsequent operations - function pointer?
   // Although this region in `outputs.C` is not touched again, streaming stores
   // do not help on SKX and Zen4. TODO: re-check this.
+  template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
-                               const IndexRange& range_nc, const MMArgs& args) {
+                               const IndexRange& range_nc, const MMArgs& args,
+                               const RowPtr<TC>& C) {
     size_t row_c = range_mc.begin();
     if (args.add) {
      constexpr bool kAdd = true;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args);
+          Do4Rows<kAdd>(row_c, range_nc, args, C);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args);
+        Do1Row<kAdd>(row_c, range_nc, args, C);
       }
     } else {
       constexpr bool kAdd = false;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args);
+          Do4Rows<kAdd>(row_c, range_nc, args, C);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args);
+        Do1Row<kAdd>(row_c, range_nc, args, C);
       }
     }
   }
 
  private:
   // Unrolled for 4 rows to reduce the number of loads from `add`.
-  template <bool kAdd>
+  template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args) {
+                                 const MMArgs& args, const RowPtr<TC>& C) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
+    const hn::Rebind<TC, decltype(dd)> dc;
     using VD = hn::Vec<decltype(dd)>;
+    using VF = hn::Vec<decltype(df)>;
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
 
@@ -669,10 +685,10 @@ class MMScaleDemoteAdd {
     const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
     const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
 
-    float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0);
-    float* HWY_RESTRICT cr1 = args.C.Row(row_c + 1);
-    float* HWY_RESTRICT cr2 = args.C.Row(row_c + 2);
-    float* HWY_RESTRICT cr3 = args.C.Row(row_c + 3);
+    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
+    TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
+    TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -713,15 +729,24 @@ class MMScaleDemoteAdd {
         m30 = hn::Mul(d30, vscale);
         m31 = hn::Mul(d31, vscale);
       }
+      // First convert f64 to f32.
+      const VF f00 = hn::DemoteTo(df, m00);
+      const VF f01 = hn::DemoteTo(df, m01);
+      const VF f10 = hn::DemoteTo(df, m10);
+      const VF f11 = hn::DemoteTo(df, m11);
+      const VF f20 = hn::DemoteTo(df, m20);
+      const VF f21 = hn::DemoteTo(df, m21);
+      const VF f30 = hn::DemoteTo(df, m30);
+      const VF f31 = hn::DemoteTo(df, m31);
       // Note that Stream is neutral on SKX and harmful on Zen4.
-      hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c);
-      hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND);
-      hn::Store(hn::DemoteTo(df, m10), df, cr1 + col_c);
-      hn::Store(hn::DemoteTo(df, m11), df, cr1 + col_c + ND);
-      hn::Store(hn::DemoteTo(df, m20), df, cr2 + col_c);
-      hn::Store(hn::DemoteTo(df, m21), df, cr2 + col_c + ND);
-      hn::Store(hn::DemoteTo(df, m30), df, cr3 + col_c);
-      hn::Store(hn::DemoteTo(df, m31), df, cr3 + col_c + ND);
+      hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
+      hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
+      hn::Store(TCFromF32(dc, f10), dc, cr1 + col_c);
+      hn::Store(TCFromF32(dc, f11), dc, cr1 + col_c + ND);
+      hn::Store(TCFromF32(dc, f20), dc, cr2 + col_c);
+      hn::Store(TCFromF32(dc, f21), dc, cr2 + col_c + ND);
+      hn::Store(TCFromF32(dc, f30), dc, cr3 + col_c);
+      hn::Store(TCFromF32(dc, f31), dc, cr3 + col_c + ND);
     }
   }
 
@@ -750,25 +775,31 @@ class MMScaleDemoteAdd {
         m20 = hn::Mul(d20, vscale);
         m30 = hn::Mul(d30, vscale);
       }
-      hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining);
-      hn::StoreN(hn::DemoteTo(df, m10), df, cr1 + col_c, remaining);
-      hn::StoreN(hn::DemoteTo(df, m20), df, cr2 + col_c, remaining);
-      hn::StoreN(hn::DemoteTo(df, m30), df, cr3 + col_c, remaining);
+      // First convert f64 to f32.
+      const VF f00 = hn::DemoteTo(df, m00);
+      const VF f10 = hn::DemoteTo(df, m10);
+      const VF f20 = hn::DemoteTo(df, m20);
+      const VF f30 = hn::DemoteTo(df, m30);
+      hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
+      hn::StoreN(TCFromF32(dc, f10), dc, cr1 + col_c, remaining);
+      hn::StoreN(TCFromF32(dc, f20), dc, cr2 + col_c, remaining);
+      hn::StoreN(TCFromF32(dc, f30), dc, cr3 + col_c, remaining);
     }
   }
 
   // Same as above but handles a single row (for remainder rows).
-  template <bool kAdd>
+  template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args) {
+                                const MMArgs& args, const RowPtr<TC>& C) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
+    const hn::Rebind<TC, decltype(dd)> dc;
     using VD = hn::Vec<decltype(dd)>;
+    using VF = hn::Vec<decltype(df)>;
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
 
     const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
-    float* HWY_RESTRICT cr0 = args.C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -791,9 +822,12 @@ class MMScaleDemoteAdd {
         m00 = hn::Mul(d00, vscale);
         m01 = hn::Mul(d01, vscale);
       }
+      // First convert f64 to f32.
+      const VF f00 = hn::DemoteTo(df, m00);
+      const VF f01 = hn::DemoteTo(df, m01);
       // Note that Stream is neutral on SKX and harmful on Zen4.
-      hn::Store(hn::DemoteTo(df, m00), df, cr0 + col_c);
-      hn::Store(hn::DemoteTo(df, m01), df, cr0 + col_c + ND);
+      hn::Store(TCFromF32(dc, f00), dc, cr0 + col_c);
+      hn::Store(TCFromF32(dc, f01), dc, cr0 + col_c + ND);
     }
   }
 
@@ -813,7 +847,9 @@ class MMScaleDemoteAdd {
       } else {
         m00 = hn::Mul(d00, vscale);
       }
-      hn::StoreN(hn::DemoteTo(df, m00), df, cr0 + col_c, remaining);
+      // First convert f64 to f32.
+      const VF f00 = hn::DemoteTo(df, m00);
+      hn::StoreN(TCFromF32(dc, f00), dc, cr0 + col_c, remaining);
     }
   }
 };  // MMScaleDemoteAdd
@@ -849,20 +885,21 @@ class MMPerPackage {
 
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
-  template <typename TB>
-  HWY_NOINLINE void operator()(const ConstMat<TB>& B) const {
+  template <typename TB, typename TC>
+  HWY_NOINLINE void operator()(const ConstMat<TB>& B,
+                               const RowPtr<TC>& C) const {
     // TODO: include NUQ tables? NumPacked in ConstMat?
     const size_t num_packed_B = B.ofs + B.Stride() * B.Extents().rows;
 
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, num_packed_B);
+        return DoNT(B, num_packed_B, C);
       case MMOrder::kNT_K:
-        return DoNT_K(B, num_packed_B);
+        return DoNT_K(B, num_packed_B, C);
      case MMOrder::kNT_MT:
-        return DoNT_MT(B, num_packed_B);
+        return DoNT_MT(B, num_packed_B, C);
      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, num_packed_B);
+        return DoNT_MT_K(B, num_packed_B, C);
       default:
         HWY_UNREACHABLE;
     }
@@ -878,13 +915,14 @@ class MMPerPackage {
 
   // Granularity of `ForNP`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.
-  static size_t MultipleNP() {
-    return HWY_MAX(kNR, Allocator::LineBytes() / sizeof(float));
+  static size_t MultipleNP(size_t sizeof_TC) {
+    return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
   }
 
   // Single M and K, parallel N. Fills all of C directly.
-  template <typename TB>
-  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B) const {
+  template <typename TB, typename TC>
+  HWY_INLINE void DoNT(const ConstMat<TB>& B, size_t num_packed_B,
+                       const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -897,7 +935,7 @@ class MMPerPackage {
 
     // Similar to `loop_nc` below, but here we hoisted `A_view`.
     args_.env->parallel.ForNP(
-        range_np_, MultipleNP(), inner_tasks_, pkg_idx_,
+        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
         [&](const IndexRange& range_nc) HWY_ATTR {
          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
           const RowPtrBF B_view(B_storage, K, B_stride);
@@ -910,7 +948,7 @@ class MMPerPackage {
               DecompressB(B, num_packed_B, row_b, range_K, B_view);
             }
             MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                           args_);
+                           args_, C);
           }
         });
 
@@ -918,8 +956,9 @@ class MMPerPackage {
   }
 
   // Single M, parallel N, sequential K. Fills all of partial.
-  template <typename TB>
-  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B) const {
+  template <typename TB, typename TC>
+  HWY_INLINE void DoNT_K(const ConstMat<TB>& B, size_t num_packed_B,
+                         const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -942,13 +981,13 @@ class MMPerPackage {
           zone.MaybeEnter("MM.NT_K.DecB", args_);
           DecompressB(B, num_packed_B, row_b, range_kc, B_view);
         }
-        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
-                       args_);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C);
       }
     };
 
     args_.env->parallel.ForNP(
-        range_np_, MultipleNP(), inner_tasks_, pkg_idx_,
+        range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
         [&](const IndexRange& range_nc) HWY_ATTR {
           HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
 
@@ -965,13 +1004,13 @@ class MMPerPackage {
     MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
       fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
-      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_);
+      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
     } else if (out_ == MMOut::kParM) {
       fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
       args_.env->parallel.ForRangeMC(
           range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
             MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
-                                    args_);
+                                    args_, C);
           });
     } else {
       HWY_UNREACHABLE;  // kDirect is only used with kNT.
@@ -980,8 +1019,9 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template <typename TB>
-  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B) const {
+  template <typename TB, typename TC>
+  HWY_INLINE void DoNT_MT(const ConstMat<TB>& B, size_t num_packed_B,
+                          const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1006,7 +1046,7 @@ class MMPerPackage {
             DecompressB(B, num_packed_B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                         args_);
+                         args_, C);
         }
       });
 
@@ -1015,8 +1055,9 @@ class MMPerPackage {
 
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
-  template <typename TB>
-  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B) const {
+  template <typename TB, typename TC>
+  HWY_INLINE void DoNT_MT_K(const ConstMat<TB>& B, size_t num_packed_B,
+                            const RowPtr<TC>& C) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1039,8 +1080,8 @@ class MMPerPackage {
           zone.MaybeEnter("MM.NT_MT_K.DecB", args_);
           DecompressB(B, num_packed_B, row_b, range_kc, B_view);
         }
-        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag,
-                       args_);
+        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
+                       C);
       }
     };  // loop_nc
     args_.env->parallel.ForRangesMC_NC(
@@ -1063,7 +1104,7 @@ class MMPerPackage {
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
           fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
-          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_);
+          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
         });
   }
 
@@ -1083,6 +1124,8 @@ class MMPerPackage {
                           const IndexRange& range_K) HWY_ATTR {
     const size_t col0 = range_K.begin();
     const size_t cols = range_K.Num();
+    // otherwise, padding overwrites neighbors
+    HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols);
     for (size_t row_a : range_M) {
       const PackedSpan<const TA> from =
           MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
@@ -1224,9 +1267,10 @@ struct MMImpl {
 
   // Called from `MatMul` from two places: either with the next autotune config,
   // or with the best config.
-  template <typename TA, typename TB>
+  template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A,
-                                    const ConstMat<TB>& B, const MMArgs& args,
+                                    const ConstMat<TB>& B, const RowPtr<TC>& C,
+                                    const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1235,7 +1279,7 @@ struct MMImpl {
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
           const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage(A, args, config, pkg_idx, range_np)(B);
+          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
         });
   }
 };
@@ -1260,10 +1304,10 @@ struct MMImpl {
 // invalidated by the next call to `MatMul`.
 //
 // Uses considerable stack space: at least 40 KiB per thread.
-template <typename TA, typename TB>
+template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtrF& C) {
+                              const RowPtr<TC>& C) {
   const size_t M = A.Extents().rows;
   const size_t K = A.Extents().cols;
   const size_t N = B.Extents().rows;
@@ -1281,15 +1325,16 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
 
     // invalidates `MMAutoTune::Best()`
     index = env.per_key.size();
-    env.per_key.push_back(MMPerKey(max_packages, N, kNR, env.parallel));
+    env.per_key.push_back(
+        MMPerKey(max_packages, N, sizeof(TC), kNR, env.parallel));
   }
   MMPerKey& per_key = env.per_key[index];
   MMAutoTune<MMConfig>& tuner = per_key.autotune;
 
   const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add,
-                    env.storage.Partial(), C);
+                    env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, args, *tuner.Best());
+    MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
     return &per_key;
   }
 
@@ -1306,13 +1351,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
     HWY_ASSERT(N % kNR == 0);
 
     // Negligible CPU time.
-    tuner.SetCandidates(MMCandidates(M, K, N, MMKernel::kMaxMR, kNR,
+    tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
                                      per_key.ranges_np, env.print_config));
   }
 
   const MMConfig& cfg = tuner.NextConfig();
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, args, cfg);
+  MMImpl::DoMatMul(A, B, C, args, cfg);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -60,10 +60,13 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 // and holds most of their arguments in member variables.
 class GenerateCandidates {
  public:
-  GenerateCandidates(size_t M, size_t K, size_t N, size_t max_mr, size_t nr,
+  GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
+                     size_t max_mr, size_t nr,
                      const IndexRangePartition& ranges_np, bool print_config)
       : M_(M),
         K_(K),
+        N_(N),
+        sizeof_TC_(sizeof_TC),
        max_mr_(max_mr),
         nr_(nr),
         // These influence kc/nc, but are also stored in `MMConfig` for
@@ -71,7 +74,7 @@ class GenerateCandidates {
         // is likely still in L1, but we expect K > 1000 and might as well round
         // up to the line size.
         kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
-        nc_multiple_(Allocator::StepBytes() / sizeof(float)),
+        nc_multiple_(Allocator::StepBytes() / sizeof_TC),
         ranges_np_(ranges_np),
         print_config_(print_config) {}
 
@@ -88,7 +91,7 @@ class GenerateCandidates {
         for (size_t nc : NC(mr, mc, kc, order)) {
           for (int inner_tasks : all_inner_tasks) {
             for (MMOut out : all_outs) {
-              const MMConfig config(K_, mr, mc, kc, nc, kc_multiple_,
+              const MMConfig config(K_, N_, mr, mc, kc, nc, kc_multiple_,
                                     nc_multiple_, order, out, inner_tasks);
               const size_t M_tasks = config.RangesOfMC(M_).NumTasks();
               const size_t K_tasks = config.RangesOfKC(K_).NumTasks();
@@ -114,7 +117,7 @@ class GenerateCandidates {
  private:
  using SizeVec = std::vector<size_t>;
 
-  // How many rows of A per call to `MMKernel::LoopOverKC`. Lower values may
+  // How many rows of A per call to `MMKernel::LoopKC`. Lower values may
   // be better for SIMD targets with fewer registers.
   SizeVec MR() const {
     const int64_t target = hwy::DispatchedTarget();
@@ -153,7 +156,7 @@ class GenerateCandidates {
 
   // The number of A and B columns to read between updating `partial`.
   SizeVec KC(size_t mr, MMOrder order) const {
-    // `LoopOverKC` handles up to `mr` rows of A.
+    // `LoopKC` handles up to `mr` rows of A.
     const size_t rows_a = HWY_MIN(M_, mr);
 
     // After looping over `kc` columns, we write `mr x 4` outputs and 16 vector
@@ -161,9 +164,9 @@ class GenerateCandidates {
     // is important that B fits in L1, because batch=1 only has a single row of
     // A and thus no reuse of the packed B. When L1-resident, we can use the
     // separate `DecompressAndZeroPad` to write `kc` columns, rather than having
-    // to integrate `Decompress2` into `LoopOverKC`, which is less efficient for
+    // to integrate `Decompress2` into `LoopKC`, which is less efficient for
    // TB=NUQ due to less amortization of the table loads. Due to the low L1
-    // latency, the packing is still effectively fused into `LoopOverKC`. It may
+    // latency, the packing is still effectively fused into `LoopKC`. It may
     // be better to round up and accept a few L2 accesses in exchange for
     // fewer loops over K, and thus fewer writes to `partial`. Hence we do not
     // subtract the output and buf, and allow using more than the actual L1
@@ -255,7 +258,7 @@ class GenerateCandidates {
   SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
     const size_t np_max = ranges_np_.TaskSize();
     size_t nc_max = np_max;
-    const size_t out_bytes = IsOneKC(order) ? sizeof(float) : sizeof(double);
+    const size_t out_bytes = IsOneKC(order) ? sizeof_TC_ : sizeof(double);
     // Only if there will be reuse of B: choose the largest `nc_max` (C cols)
     // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
     // Otherwise, leave it unbounded.
@@ -350,6 +353,8 @@ class GenerateCandidates {
 
   const size_t M_;
   const size_t K_;
+  const size_t N_;
+  const size_t sizeof_TC_;
 
   const size_t max_mr_;
   const size_t nr_;
@@ -365,23 +370,25 @@ class GenerateCandidates {
 }  // namespace
 
 // Facade to avoid exposing `GenerateCandidates` in the header.
-std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
-                                   size_t nr,
+std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
+                                   size_t sizeof_TC, size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config) {
-  return GenerateCandidates(M, K, N, max_mr, nr, ranges_np, print_config)();
+  return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
+                            print_config)();
 }
 
 // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
 // memory accesses or false sharing, unless there are insufficient per-package
 // rows for that.
-static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
-  size_t np_multiple = Allocator::QuantumBytes() / sizeof(float);
+static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
+                         size_t num_packages) {
+  size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
   // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
   // `N` < 4096, this can cause significant load imbalance. If split unevenly,
   // choose a smaller multiple.
   if (N % (np_multiple * num_packages)) {
-    const size_t min_multiple = Allocator::LineBytes() / sizeof(float);
+    const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
     np_multiple =
         PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
     if (HWY_UNLIKELY(np_multiple == 0)) {
@@ -398,10 +405,10 @@ static size_t NPMultiple(size_t N, size_t nr, size_t num_packages) {
 }
 
 IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
-                                           size_t nr) const {
+                                           size_t sizeof_TC, size_t nr) const {
   const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
   return StaticPartition(IndexRange(0, N), num_packages,
-                         NPMultiple(N, nr, num_packages));
+                         NPMultiple(N, sizeof_TC, nr, num_packages));
 }
 
 MatMulEnv::MatMulEnv(NestedPools& pools) : parallel(pools), storage(parallel) {
ops/matmul.h (62 changed lines)
@ -60,7 +60,7 @@ class MMParallel {
|
||||||
|
|
||||||
// Initial static partitioning of B rows across packages.
|
// Initial static partitioning of B rows across packages.
|
||||||
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
|
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
|
||||||
size_t nr) const;
|
size_t sizeof_TC, size_t nr) const;
|
||||||
|
|
||||||
// For `BindB` and `BindC`.
|
// For `BindB` and `BindC`.
|
||||||
size_t Node(size_t pkg_idx) const {
|
size_t Node(size_t pkg_idx) const {
|
||||||
|
|
@@ -170,13 +170,13 @@ class MMParallel {
   NestedPools& pools_;
 };
 
-template <typename T>  // float for C, double for partial
-void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
+template <typename TC>  // BF16/float for C, double for partial
+void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
   if (!Allocator::ShouldBind()) return;
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), kNR);
-  const size_t quantum = Allocator::QuantumBytes() / sizeof(T);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
+  const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
   bool ok = true;
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& cols_c = ranges_np.Range(pkg_idx);
@ -185,7 +185,7 @@ void BindC(size_t M, const RowPtr<T>& C, MMParallel& parallel) {
       // BindRowsToPackageNodes may not be page-aligned.
       const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
       const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
-      ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(T),
+      ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
                                   node);
     }
 }
 
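For context on why `sizeof_TC` is now threaded down to `NPMultiple`: `BindC` above binds each package's slice of C at the allocator's quantum granularity, and that quantum is measured in C elements (`Allocator::QuantumBytes() / sizeof(TC)`), so it depends on the output element size. A minimal standalone sketch of that arithmetic, assuming a 4 KiB bind quantum (the real value comes from `Allocator::QuantumBytes()`; `uint16_t` stands in for the 2-byte BF16 type):

    // Sketch only: a BF16 C packs twice as many elements into one bindable
    // quantum as a float C, so column ranges must round to a larger multiple.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const size_t quantum_bytes = 4096;  // assumed bind granularity in bytes
      const size_t quantum_f32 = quantum_bytes / sizeof(float);      // 1024 elements
      const size_t quantum_bf16 = quantum_bytes / sizeof(uint16_t);  // 2048 elements
      printf("quantum: %zu f32 elements, %zu bf16 elements\n", quantum_f32,
             quantum_bf16);
      return 0;
    }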
@ -355,11 +355,11 @@ static inline const char* StringFromParA(MMParA par_a) {
 class MMConfig {
  public:
   MMConfig() = default;  // for std::vector
-  // `mr` is the number of A rows per call to `MMKernel::LoopOverKC`.
+  // `mr` is the number of A rows per call to `MMKernel::LoopKC`.
   // `MMOrder` is how to parallelize the outer loops.
   // `MMOut` is how/whether to parallelize filling the C result.
   // `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
-  MMConfig(size_t K, size_t mr, size_t mc, size_t kc, size_t nc,
+  MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
            size_t kc_multiple, size_t nc_multiple, MMOrder order, MMOut out,
            int inner_tasks)
       : mr_(static_cast<uint32_t>(mr)),
@ -381,7 +381,7 @@ class MMConfig {
     if (kc != K && (kc % kc_multiple) != 0) {
       HWY_WARN("kc %zu not a multiple of kc_multiple %zu", kc, kc_multiple);
     }
-    if (nc % nc_multiple != 0) {
+    if (nc != N && (nc % nc_multiple) != 0) {
       HWY_WARN("nc %zu not a multiple of nc_multiple %zu", nc, nc_multiple);
     }
     HWY_DASSERT(StringFromOrder(order_) != nullptr);
@ -428,8 +428,8 @@ class MMConfig {
 static_assert(sizeof(MMConfig) == 32);  // for faster indexing
 #pragma pack(pop)
 
-std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N, size_t max_mr,
-                                   size_t nr,
+std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
+                                   size_t sizeof_TC, size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config);
 
@ -588,8 +588,9 @@ class MMKeys {
 
 // Per-MatMul-shape state.
 struct MMPerKey {
-  MMPerKey(size_t max_packages, size_t N, size_t nr, MMParallel& parallel)
-      : ranges_np(parallel.RangesOfNP(max_packages, N, nr)) {}
+  MMPerKey(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr,
+           MMParallel& parallel)
+      : ranges_np(parallel.RangesOfNP(max_packages, N, sizeof_TC, nr)) {}
 
   // Only profile if enabled and the main autotuner finished (the par_a
   // autotuner is per-package and we want to avoid synchronization).
@ -623,18 +624,16 @@ struct MatMulEnv {
   std::vector<MMPerKey> per_key;
 };
 
-// Arguments to MatMul() that are independent of the A/B type.
+// Arguments to MatMul() that are independent of the A/B/C types.
 // Reduces register pressure compared to individual values/references.
 struct MMArgs {
   MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
-         const float* HWY_RESTRICT add, const RowPtrD& partial,
-         const RowPtrF& C)
+         const float* HWY_RESTRICT add, const RowPtrD& partial)
       : env(&env),
         per_key(&per_key),
         scale(scale),
         add(add),
-        partial(partial),
-        C(C) {}
+        partial(partial) {}
 
   MatMulEnv* env;
   MMPerKey* per_key;
@ -643,7 +642,6 @@ struct MMArgs {
   const float* HWY_RESTRICT add;
   // Same size as C, threads write at false-sharing-free granularity.
   RowPtrD partial;
-  RowPtrF C;
 };
 
 // Wrapper over hwy::Zone that is only enabled when autotuning finished.
@ -683,22 +681,22 @@ struct MMZone {
 // `ofs` required for compressed T.
 template <typename T>
 struct ConstMat {
-  ConstMat(const T* ptr, Extents2D extents, size_t ofs = 0)
-      : ptr(ptr), extents(extents), ofs(ofs) {
+  ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
+      : ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
     HWY_DASSERT(ptr != nullptr);
+    HWY_DASSERT(stride >= extents.cols);
   }
-  // TODO: support stride for page alignment.
   size_t Row(size_t r) const {
     if constexpr (HWY_IS_DEBUG_BUILD) {
       if (r >= extents.rows) {
         HWY_ABORT("ConstMat::Row %zu out of bounds %zu", r, extents.rows);
       }
     }
-    return ofs + extents.cols * r;
+    return ofs + r * stride;
   }
 
   const Extents2D& Extents() const { return extents; }
-  size_t Stride() const { return extents.cols; }
+  size_t Stride() const { return stride; }
 
   // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
   // subrange of the original rows starting at row 0.
@ -709,6 +707,7 @@ struct ConstMat {
 
   const T* HWY_RESTRICT ptr;
   Extents2D extents;
+  size_t stride;
 
   // `scale` allows expanding the smaller range of `SfpStream` to the original
   // values. MatFromWeights sets this from `MatPtr`.
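To illustrate the new `stride` member: `Row(r)` now returns `ofs + r * stride` rather than `ofs + r * extents.cols`, so a `ConstMat` can view rows that are padded beyond their logical width; callers index as `A.ptr + A.Row(r)`, as in `MatMulSlow` later in this diff. A self-contained sketch of that indexing, using a stand-in struct rather than the real `ConstMat`:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct ConstMatSketch {  // stand-in for the real ConstMat<float>
      const float* ptr;
      size_t rows, cols, stride, ofs;
      size_t Row(size_t r) const {
        assert(r < rows && stride >= cols);
        return ofs + r * stride;  // element offset, not a pointer
      }
    };

    int main() {
      const size_t rows = 3, cols = 5, stride = 8;  // rows padded to 8 floats
      std::vector<float> storage(rows * stride, 0.0f);
      ConstMatSketch mat{storage.data(), rows, cols, stride, /*ofs=*/0};
      const float* row2 = mat.ptr + mat.Row(2);  // starts at element 16, not 10
      (void)row2;
      return 0;
    }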
@ -721,9 +720,9 @@ struct ConstMat {
 
 // For deducing T.
 template <typename T>
-ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
+ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
                          size_t ofs = 0) {
-  return ConstMat<T>(ptr, extents, ofs);
+  return ConstMat<T>(ptr, extents, stride, ofs);
 }
 
 // For A argument to MatMul (activations).
@ -732,22 +731,25 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,
                               const RowVectorBatch<T>& row_vectors) {
   HWY_DASSERT(batch_size <= row_vectors.BatchSize());
   return MakeConstMat(const_cast<T*>(row_vectors.Const()),
-                      Extents2D(batch_size, row_vectors.Cols()));
+                      Extents2D(batch_size, row_vectors.Cols()),
+                      row_vectors.Stride());
 }
 
 template <typename T>
 ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
-  ConstMat<T> mat = MakeConstMat(const_cast<T*>(m.data()), m.Extents(), ofs);
+  ConstMat<T> mat =
+      MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
   mat.scale = m.scale();
   return mat;
 }
 
 template <typename TB>
-void BindB(size_t N, const ConstMat<TB>& B, MMParallel& parallel) {
+void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
+           MMParallel& parallel) {
   if (!Allocator::ShouldBind()) return;
 
   const IndexRangePartition ranges_np =
-      parallel.RangesOfNP(MMParallel::kMaxPackages, N, kNR);
+      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
   const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
   for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
     const IndexRange& rows_b = ranges_np.Range(pkg_idx);
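Taken together, these header changes are what let a caller request BF16 output from MatMul. The fragment below mirrors the pattern used by TestMatMul later in this diff; it assumes the declarations from ops/matmul.h and the test helpers (`AllocateAlignedRows`, `RowPtrFromBatch`) are in scope, and `ExampleBF16Output` is an illustrative name, not part of this change:

    // Sketch, not a standalone program: produce BF16 C from f32/BF16 inputs.
    template <typename TA, typename TB>
    void ExampleBF16Output(const ConstMat<TA>& A, const ConstMat<TB>& B,
                           MatMulEnv& env) {
      // B is already transposed, so C has B.Extents().rows columns.
      const Extents2D C_extents(A.Extents().rows, B.Extents().rows);
      RowVectorBatch<BF16> c_batch = AllocateAlignedRows<BF16>(C_extents);
      const RowPtr<BF16> C = RowPtrFromBatch(c_batch);  // TC = BF16
      MatMul(A, B, /*add=*/nullptr, env, C);  // accumulates in f32, stores BF16
    }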
ops/matmul_test.cc

@ -67,7 +67,7 @@ using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
 
 // Generates inputs: deterministic, within max SfpStream range.
 template <typename MatT>
-MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
+MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
                                 hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   auto mat =
@ -112,12 +112,12 @@ MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
 }
 
 // Returns 1-norm, used for estimating tolerable numerical differences.
-double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
+double MaxRowAbsSum(const RowVectorBatch<float>& a) {
   double max_row_abs_sum = 0.0;
-  for (size_t r = 0; r < extents.rows; r++) {
-    const float* row = a + r * extents.cols;
+  for (size_t r = 0; r < a.BatchSize(); r++) {
+    const float* row = a.Batch(r);
     double row_abs_sum = 0.0;
-    for (size_t c = 0; c < extents.cols; c++) {
+    for (size_t c = 0; c < a.Cols(); c++) {
       row_abs_sum += hwy::ScalarAbs(row[c]);
     }
     max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
@ -126,41 +126,52 @@ double MaxRowAbsSum(const float* HWY_RESTRICT a, const Extents2D& extents) {
 }
 
 // Returns the maximum absolute value of `a`.
-float MaxAbs(const float* HWY_RESTRICT a, const Extents2D& extents) {
+float MaxAbs(const RowVectorBatch<float>& a) {
   float max_abs = 0.0f;
-  for (size_t c = 0; c < extents.cols; c++) {
-    for (size_t r = 0; r < extents.rows; r++) {
-      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(a[r * extents.cols + c]));
+  for (size_t c = 0; c < a.Cols(); c++) {
+    for (size_t r = 0; r < a.BatchSize(); r++) {
+      const float* row = a.Batch(r);
+      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
     }
   }
   return max_abs;
 }
 
 // B is already transposed.
-template <typename TA, typename TB>
+template <typename TA, typename TB, typename TC>
 void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
-                 const RowPtrF& C_slow, const RowPtrF& C) {
+                 const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
   const hn::ScalableTag<float> df;
-  const size_t num_a = A.extents.Area();
-  const size_t num_b = B.extents.Area();
-  const size_t N = hn::Lanes(df);
+  const size_t cols = A.extents.cols;
+  const size_t B_rows = B.extents.rows;
   // Round up for DecompressAndZeroPad.
-  FloatPtr a = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_a, N));
-  FloatPtr b_trans = hwy::AllocateAligned<float>(hwy::RoundUpTo(num_b, N));
-  HWY_ASSERT(a && b_trans);
+  RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
+  RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
+  RowVectorBatch<float> c_batch =
+      AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
+  RowVectorBatch<float> c_slow_batch =
+      AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
   HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
-  DecompressAndZeroPad(df, MakeSpan(A.ptr, num_a), 0, a.get(), num_a);
-  DecompressAndZeroPad(df, MakeSpan(B.ptr, num_b), 0, b_trans.get(), num_b);
+  for (size_t m = 0; m < A.extents.rows; ++m) {
+    DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
+                         a_batch.Batch(m), cols);
+    DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
+                         B_rows);
+    DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
+                         c_slow_batch.Batch(m), B_rows);
+  }
+  for (size_t n = 0; n < B_rows; ++n) {
+    DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
+                         b_trans_batch.Batch(n), cols);
+  }
 
   // MatMul rounds inputs to BF16, so error is proportional to the max input
   // magnitude, but also to f32 accumulation of rows in A and B.
-  const double norm = MaxRowAbsSum(a.get(), A.Extents()) *
-                      MaxRowAbsSum(b_trans.get(), B.Extents());
-  const float max_abs =
-      MaxAbs(a.get(), A.Extents()) * MaxAbs(b_trans.get(), B.Extents());
+  const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
+  const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
   const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
   const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
-  double tolerance = 10 * norm * eps_f32;
+  double tolerance = 12 * norm * eps_f32;
   // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
   // tolerance there.
   if (IsF32<TA>() && IsF32<TB>()) {
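To make the tolerance concrete: with 256x256 inputs whose entries are at most 1 in magnitude (an assumed example, not a value taken from the test), `MaxRowAbsSum` is at most 256 for each operand, so `norm <= 65536` and the absolute tolerance evaluates to roughly 0.09:

    #include <cstdio>

    int main() {
      const double eps_f32 = 1.1920929e-7;  // hwy::Epsilon<float>() == 2^-23
      const double norm = 256.0 * 256.0;    // MaxRowAbsSum(A) * MaxRowAbsSum(B)
      const double tolerance = 12 * norm * eps_f32;
      printf("tolerance %.4f\n", tolerance);  // ~0.0938
      return 0;
    }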
@ -169,30 +180,38 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
   if (tolerance > 500.0) {
     HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
   }
+  const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
 
   for (size_t r = 0; r < A.extents.rows; r++) {
-    const float* expected_row = C_slow.Row(r);
-    const float* actual_row = C.Row(r);
+    const float* expected_row = c_slow_batch.Batch(r);
+    const float* actual_row = c_batch.Batch(r);
     for (size_t c = 0; c < B.extents.rows; c++) {
       const double expected_value = static_cast<double>(expected_row[c]);
       const double actual_value = static_cast<double>(actual_row[c]);
+      const bool in_range = expected_value - tolerance <= actual_value &&
+                            actual_value <= expected_value + tolerance;
 
-      if (!(expected_value - tolerance <= actual_value &&
-            actual_value <= expected_value + tolerance)) {
-        HWY_ABORT(
-            "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
-            "tolerance %f\n",
-            r, c, expected_value, actual_value, norm, max_abs, tolerance);
+      if (!in_range) {
+        const double max = HWY_MAX(expected_value, actual_value);
+        const double min = HWY_MIN(expected_value, actual_value);
+        const double rel = max / HWY_MAX(min, 1E-6);
+        if (rel > max_rel) {
+          hwy::Abort(__FILE__, line,
+                     "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
+                     "tolerance %f rel %E max_rel %E\n",
+                     r, c, expected_value, actual_value, norm, max_abs,
+                     tolerance, rel, max_rel);
+        }
       }
     }
   }
 }
 
 // B is already transposed.
-template <typename TA, typename TB>
+template <typename TA, typename TB, typename TC>
 HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
                            const float* HWY_RESTRICT add_row, MatMulEnv& env,
-                           const RowPtrF& C) {
+                           const RowPtr<TC>& C) {
   // TA can be any Packed except NuqStream because it uses pointer
   // arithmetic, because it is the second argument to Dot, which does not
   // support a v_ofs.
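The new acceptance logic is: a C element passes if it lies inside the absolute tolerance band, and otherwise only if the expected/actual ratio stays within `max_rel` (one epsilon of the C type above 1). A self-contained restatement of that check, equivalent to the branch above:

    #include <algorithm>

    // Returns true if `actual` is acceptable relative to `expected`.
    bool WithinTolerance(double expected, double actual, double tolerance,
                         double max_rel) {
      const bool in_range = expected - tolerance <= actual &&
                            actual <= expected + tolerance;
      if (in_range) return true;
      const double max = std::max(expected, actual);
      const double min = std::min(expected, actual);
      const double rel = max / std::max(min, 1E-6);  // guard against tiny values
      return rel <= max_rel;
    }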
@ -200,7 +219,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
   const float scale = A.scale * B.scale;
 
   const hn::ScalableTag<float> df;  // lane type is ignored
-  const PackedSpan<const TB> b_span = MakeSpan(B.ptr, B.ofs + B.extents.Area());
+  const PackedSpan<const TB> b_span =
+      MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
   const IndexRange all_rows_c(0, A.Extents().rows);
   const IndexRange all_cols_c(0, C.Cols());
 
@ -219,12 +239,12 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
       get_col_c, all_clusters,
       [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
         for (size_t r : rows_c) {
-          float* HWY_RESTRICT C_row = C.Row(r);
+          TC* HWY_RESTRICT C_row = C.Row(r);
           for (size_t c : cols_c) {
             const float add = add_row ? add_row[c] : 0.0f;
-            C_row[c] =
-                add + scale * Dot(df, b_span, c * B.extents.cols,
-                                  A.ptr + A.Row(r), A.extents.cols);
+            C_row[c] = hwy::ConvertScalarTo<TC>(
+                add + scale * Dot(df, b_span, c * B.Stride(),
+                                  A.ptr + A.Row(r), A.extents.cols));
           }
         }
       });
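The only functional difference in this reference path when `TC` is BF16 is the final narrowing: the f32 accumulator is rounded to the output type via `hwy::ConvertScalarTo`. A small sketch of that rounding (assuming the conversion helpers live in hwy/base.h, as elsewhere in Highway):

    #include <cstdio>
    #include "hwy/base.h"  // hwy::bfloat16_t, hwy::ConvertScalarTo (assumed location)

    int main() {
      const float acc = 1.2345678f;  // an f32 accumulator value
      const hwy::bfloat16_t c = hwy::ConvertScalarTo<hwy::bfloat16_t>(acc);
      const float back = hwy::ConvertScalarTo<float>(c);
      printf("f32 %.7f -> bf16 -> f32 %.7f\n", acc, back);  // ~1.2343750
      return 0;
    }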
@ -239,14 +259,15 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
           elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
 }
 
-template <typename TA, typename TB = TA>
+template <typename TA, typename TB = TA, typename TC = float>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
-                MatMulEnv& env) {
+                MatMulEnv& env, int line) {
   hwy::ThreadPool& pool = env.parallel.Pools().Pool();
-  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s\n", rows_ac,
-          cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>());
+  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
+          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
+          TypeName<TC>());
 
-  env.print_config = true;
+  env.print_config = false;  // Too verbose.
   env.print_best = true;
 
   const Extents2D A_extents(rows_ac, cols_a_rows_b);
@ -255,8 +276,8 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
 
   MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
   MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
-  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
-  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
+  RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
+  RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
   HWY_ASSERT(a && b_trans);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
@ -269,14 +290,14 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   const auto A = ConstMatFromWeights(*a);
   const auto B = ConstMatFromWeights(*b_trans);
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
-  const RowPtrF C_slow = RowPtrFromBatch(c_slow_batch);
-  const RowPtrF C = RowPtrFromBatch(c_batch);
+  const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
+  const RowPtr<TC> C = RowPtrFromBatch(c_batch);
 
   MatMulSlow(A, B, add_row, env, C_slow);
   // A few reps to get coverage of the various autotuned code paths.
   for (size_t rep = 0; rep < 16; ++rep) {
     MMPerKey* per_key = MatMul(A, B, add_row, env, C);
-    AssertClose(A, B, C_slow, C);
+    AssertClose(A, B, C_slow, C, line);
     if (per_key->autotune.Best()) break;
   }
 }
@ -311,7 +332,7 @@ void TestTiny() {
   for (size_t M = 1; M <= 12; ++M) {
     for (size_t K = 1; K <= 64; K *= 2) {
       for (size_t N = 4; N <= 64; N += max_packages * 4) {
-        TestMatMul<F32, F32>(M, K, N, /*add=*/false, env);
+        TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
       }
     }
   }
@ -332,56 +353,69 @@ void TestAllMatMul() {
   Allocator::Init(pools.Topology(), /*enable_bind=*/true);
   MatMulEnv env(pools);
 
-  // Sizes seen in gemma_test 2B.
-  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env);
-  TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env);
-  TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env);
-  TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env);
-  TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env);
-  TestMatMul<F32>(5, 2048, 512, /*add=*/false, env);
-  TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env);
-  TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env);
-  TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env);
+  // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
+  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(5, 2048, 512, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env, __LINE__);
+  // TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env, __LINE__);
 
-  // medium-sized square
-  TestMatMul<F32>(512, 512, 512, /*add=*/false, env);
-  TestMatMul<BF16>(512, 512, 512, /*add=*/true, env);
-  TestMatMul<F32, BF16>(512, 512, 512, /*add=*/false, env);
-  TestMatMul<BF16, F32>(512, 512, 512, /*add=*/true, env);
-  TestMatMul<F32, SFP>(512, 512, 512, /*add=*/false, env);
-  TestMatMul<BF16, SFP>(512, 512, 512, /*add=*/true, env);
+  // medium-sized square, f32 vs bf16 for A, B, C; plus add.
+  TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
+
+  TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);
 
   // minimal non-square test. kColsARowsB must be at least 2 vectors.
-  TestMatMul<F32>(35, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16>(34, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env);
-  TestMatMul<F32>(4, 128, 32, /*add=*/true, env);
-  TestMatMul<BF16>(4, 128, 32, /*add=*/false, env);
-  TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env);
-  TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env);
-  TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env);
-  TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env);
-  TestMatMul<F32>(3, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16>(3, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env);
-  TestMatMul<F32>(2, 128, 64, /*add=*/true, env);
-  TestMatMul<BF16>(2, 128, 64, /*add=*/false, env);
-  TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env);
-  TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env);
-  TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env);
-  TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env);
-  TestMatMul<F32>(1, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16>(1, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env);
-  TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env);
-  TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env);
+  TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32>(4, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16>(4, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<F32>(3, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16>(3, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32>(2, 128, 64, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16>(2, 128, 64, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env, __LINE__);
+  TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env, __LINE__);
+  TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env, __LINE__);
+  TestMatMul<F32>(1, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16>(1, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
+  TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
+  TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)