mirror of https://github.com/google/gemma.cpp.git
Simplify MatMul: remove F32 special case (build time)
Also move kMaxM into a separate kMaxBatchSize constant.

PiperOrigin-RevId: 802086590
parent 1e3c853e80
commit b7b3d353db
@@ -248,17 +248,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 runtime_config.max_generated_tokens = max_generated_tokens;
 runtime_config.prefill_tbatch_size = prefill_tbatch_size;
 runtime_config.decode_qbatch_size = decode_qbatch_size;
-if (prefill_tbatch_size > MMStorage::kMaxM) {
+if (prefill_tbatch_size > kMaxBatchSize) {
 HWY_ABORT(
-"prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
-"or increase the constant in MMStorage.\n",
-prefill_tbatch_size, MMStorage::kMaxM);
+"prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a "
+"smaller value, or increase kMaxBatchSize.\n",
+prefill_tbatch_size, kMaxBatchSize);
 }
-if (decode_qbatch_size > MMStorage::kMaxM) {
+if (decode_qbatch_size > kMaxBatchSize) {
 HWY_ABORT(
-"decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
-"or increase the constant in MMStorage.\n",
-decode_qbatch_size, MMStorage::kMaxM);
+"decode_qbatch_size %zu > kMaxBatchSize %zu: specify a "
+"smaller value, or increase kMaxBatchSize.\n",
+decode_qbatch_size, kMaxBatchSize);
 }

 runtime_config.temperature = temperature;
ops/matmul-inl.h
@@ -96,28 +96,27 @@ struct MMSetC {};
 struct MMAddC {};

 // Stores horizontal sums of up to 16 vectors via transpose.
-template <size_t kRowsAC, bool kAdd>
+template <size_t kRowsAC>
 class MMStoreHorizontalSumsIntoC {
 public:
 static_assert(kNR == 4);  // for `StoreInterleaved4`

-// Computes horizontal sums of `kRowsAC x kNR` vectors and stores into
-// `C` starting at `(row_c, col_c)`.
-//
+// Given 16 (`kRowsAC x kNR`) full vectors of 32-bit float, returns four
+// 4-wide float vectors with their horizontal sums.
 // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
 // transposed B row vector indexed by `c`. Their elements are thus a subset
 // of the terms of the dot product constituting the final `C[r, c]` result.
 // Thus we compute the horizontal sums of each `Crc`. The elements may be
 // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
 // this does not change their horizontal sum.
-template <class DF, class VF = hn::Vec<DF>, typename TC>
-HWY_INLINE void operator()(DF df,  //
+template <class DF, class VF = hn::Vec<DF>, class D4 = hn::Full128<float>,
+class V4 = hn::Vec<D4>>
+HWY_INLINE void Reduce4x4(DF df,  //
 VF C00, VF C01, VF C02, VF C03,  //
 VF C10, VF C11, VF C12, VF C13,  //
 VF C20, VF C21, VF C22, VF C23,  //
 VF C30, VF C31, VF C32, VF C33,  //
-const size_t row_c, const size_t col_c,
-const MMArgs& args, RowPtrs<TC> C_rows) const {
+V4& sum0, V4& sum1, V4& sum2, V4& sum3) {
 HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
 HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df);
 // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -133,14 +132,13 @@ class MMStoreHorizontalSumsIntoC {
 // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
 // the elements of one V4. We have four independent rows `r`, hence the
 // code is effectively unrolled, which increases throughput.
-const hn::CappedTag<float, kNR> d4;
-using V4 = hn::Vec<decltype(d4)>;
 // Store to four elements per row of `partial`.
 // No loop is required because vectors are at least 4*32 bits.
-V4 sum0 = MaybeLoad<0>(d4, N, buf);
-V4 sum1 = MaybeLoad<1>(d4, N, buf);
-V4 sum2 = MaybeLoad<2>(d4, N, buf);
-V4 sum3 = MaybeLoad<3>(d4, N, buf);
+const D4 d4;
+sum0 = MaybeLoad<0>(d4, N, buf);
+sum1 = MaybeLoad<1>(d4, N, buf);
+sum2 = MaybeLoad<2>(d4, N, buf);
+sum3 = MaybeLoad<3>(d4, N, buf);

 for (size_t lane = 1; lane < N; ++lane) {
 sum0 = MaybeAdd<0>(d4, N, sum0, buf + kNR * lane);
@@ -148,13 +146,23 @@ class MMStoreHorizontalSumsIntoC {
 sum2 = MaybeAdd<2>(d4, N, sum2, buf + kNR * lane);
 sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane);
 }
+}
+
+// Scales the dot-product terms and adds bias (if present) and stores the
+// four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
+// `MMSetC`, the vectors are written as-is (first call, or small K).
+// Otherwise, they are partial sums and are accumulated into C.
+template <class D4, class V4 = hn::Vec<D4>, class Tag, typename TC>
+HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
+const size_t row_c, const size_t col_c,
+const MMArgs& args, RowPtrs<TC> C_rows) const {
 const V4 vscale = hn::Set(d4, args.scale);
 HWY_ALIGN static constexpr float kZero[4] = {};
 const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
-MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
-MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
-MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
+MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
+MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
+MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
+MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
 }

 private:
@@ -191,18 +199,20 @@ class MMStoreHorizontalSumsIntoC {
 }

 template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
-typename TC>
+class Tag, typename TC>
 static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-VF4 vadd, RowPtrs<TC> C_rows,
+VF4 vadd, Tag, RowPtrs<TC> C_rows,
 const size_t row_c,
 const size_t col_c) {
 if constexpr (kRow < kRowsAC) {
 TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
 const hn::Rebind<TC, DF4> dc4;
-if constexpr (kAdd) {
+if constexpr (hwy::IsSame<Tag, MMAddC>()) {
 vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
-}  // else: add bias (only once, the first time we store to C)
-
+} else {
+static_assert(hwy::IsSame<Tag, MMSetC>());
+// vadd remains the bias (added once, the first time we store to C)
+}
 const VF4 out = hn::MulAdd(sum, vscale, vadd);
 hn::Store(TCFromF32(dc4, out), dc4, pos);
 }
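Note on the refactor above: the old combined operator() is split into Reduce4x4 (horizontal sums) and Store (scale, bias or accumulate, write). A minimal scalar sketch of the intended semantics, assuming a plain row-major float C matrix and hypothetical names (the real code operates on Highway vectors):

#include <cstddef>
#include <type_traits>

struct MMSetCTag {};  // first store: C = scale * sum + bias
struct MMAddCTag {};  // later stores: C += scale * sum (bias added only once)

// Scalar model of Reduce4x4 + Store for a rows_ac x 4 tile. `partial[r][c]`
// stands in for the horizontal sum of accumulator vector Crc.
template <class Tag>
void StoreTile(const float partial[4][4], float scale, const float* bias,
               float* C, std::size_t stride, std::size_t row_c,
               std::size_t col_c, std::size_t rows_ac, Tag) {
  for (std::size_t r = 0; r < rows_ac; ++r) {
    for (std::size_t c = 0; c < 4; ++c) {
      float add;
      if constexpr (std::is_same_v<Tag, MMAddCTag>) {
        add = C[(row_c + r) * stride + col_c + c];  // accumulate into prior value
      } else {
        add = bias ? bias[col_c + c] : 0.0f;  // bias only on the first store
      }
      C[(row_c + r) * stride + col_c + c] = scale * partial[r][c] + add;
    }
  }
}

Passing a tag type instead of the old kAdd bool lets the same Store code be selected by the caller's MMSetC/MMAddC tag value.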
@@ -215,9 +225,9 @@ class MMKernel {
 // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
 // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
 // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-template <class Tag, typename TA, typename TC>
-static HWY_INLINE void A2C0(const StridedView<TA> A_view,
-const StridedViewBF& B_view, size_t mr,
+template <class Tag, typename TC>
+static HWY_INLINE void A2C0(const StridedViewBF A_view,
+const StridedViewBF B_view, size_t mr,
 const IndexRange& range_mc, const size_t row_b,
 size_t kc, Tag tag, const MMArgs& args,
 RowPtrs<TC> C_rows) {
@@ -357,34 +367,18 @@ class MMKernel {
 }
 }

-// For A=F32, B=BF16 without native BF16 dot product: one lane-crossing
-// promotion is likely cheaper than AND+SHIFT for promoting odd/even BF.
-// Caller already promoted B, so all inputs are F32.
-template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
-static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2,
-VF b3, VF& C0, VF& C1, VF& C2,
-VF& C3) {
-HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
-C0 = hn::MulAdd(a, b0, C0);
-C1 = hn::MulAdd(a, b1, C1);
-C2 = hn::MulAdd(a, b2, C2);
-C3 = hn::MulAdd(a, b3, C3);
-}
-
 // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
 // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
 // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
 // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
-// `B` is BF16, `A` and `C` can be F32 or BF16.
-template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
-static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
-const StridedViewBF& B_view, size_t row_ac,
+// `A` and `B` are always BF16, `C` can be F32 or BF16.
+template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
+static HWY_INLINE void LoopKC(const StridedViewBF A_view,
+const StridedViewBF B_view, size_t row_ac,
 size_t imc, size_t col_c, size_t kc, Tag tag,
 const MMArgs& args, RowPtrs<TC> C_rows) {
 const hn::ScalableTag<BF16> dbf;
 using VBF = hn::Vec<decltype(dbf)>;
-
-HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag<TA>());
 HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);

 HWY_DASSERT(kRowsAC <= kMaxMR);
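The removed F32 helper and the remaining BF16 helpers all implement the same rank-1 update of a 4 x 4 accumulator tile. A scalar sketch of that update, with hypothetical names and one "lane" shown (the real helpers do this per vector lane, per kc step):

// Each of up to 4 A rows is multiplied by the same 4 transposed B rows and
// accumulated into 16 independent partial sums; Reduce4x4 later sums the
// lanes of each accumulator horizontally.
void MulAccTileLane(const float a[4], const float b[4], float C[4][4]) {
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      C[r][c] += a[r] * b[c];  // one MulAdd per (row, column) accumulator
    }
  }
}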
@@ -393,10 +387,10 @@ class MMKernel {

 // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
 static_assert(kNR == 4);
-const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
-const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
-const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
-const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
+const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
+const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
+const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
+const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
 const BF16* HWY_RESTRICT br0 = B_view.Row(0);
 const BF16* HWY_RESTRICT br1 = B_view.Row(1);
 const BF16* HWY_RESTRICT br2 = B_view.Row(2);
@@ -416,8 +410,6 @@ class MMKernel {
 C33 = hn::Zero(df);

 size_t ikc = 0;
-// The loop step is always NBF: for non-native BF16 with TA=F32, this
-// entails 2x unrolling, which helps a little.
 const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
 if (kc >= kc_step) {
 HWY_UNROLL(1)
@@ -432,10 +424,6 @@ class MMKernel {
 const VBF b2 = hn::LoadU(dbf, br2 + ikc);
 const VBF b3 = hn::LoadU(dbf, br3 + ikc);

-// Should only get here if `A` is BF16, otherwise `DecompressA` would
-// convert to BF16 and `A_view` points to that.
-HWY_DASSERT(IsBF16<TA>());
-
 {
 const VBF a0 = hn::Load(dbf, ar0 + ikc);
 ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -457,7 +445,6 @@ class MMKernel {
 C33);
 }
 } else {  // !HWY_NATIVE_DOT_BF16
-if constexpr (IsBF16<TA>()) {
 // When both are BF16, it is better to load promote odd/even,
 // because lane-crossing promotion for both might be bottlenecked on
 // shuffles.
@@ -483,77 +470,16 @@ class MMKernel {
 const VBF a0 = hn::Load(dbf, ar0 + ikc);
 const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
 ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-b3o, b3e, C00, C01, C02, C03, C10, C11,
-C12, C13);
+b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
+C13);
 }
 if constexpr (kRowsAC > 2) {
 const VBF a2 = hn::Load(dbf, ar2 + ikc);
 const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
 ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-b3o, b3e, C20, C21, C22, C23, C30, C31,
-C32, C33);
-}
-} else {  // IsF32<TA>(): promote BF to 2xF32, F32*F32.
-// Full-vector loads are a bit faster on SKX than half + PromoteTo.
-const VBF b0 = hn::LoadU(dbf, br0 + ikc);
-const VBF b1 = hn::LoadU(dbf, br1 + ikc);
-const VBF b2 = hn::LoadU(dbf, br2 + ikc);
-const VBF b3 = hn::LoadU(dbf, br3 + ikc);
-const VF b00 = hn::PromoteLowerTo(df, b0);
-const VF b10 = hn::PromoteLowerTo(df, b1);
-const VF b20 = hn::PromoteLowerTo(df, b2);
-const VF b30 = hn::PromoteLowerTo(df, b3);
-const VF b01 = hn::PromoteUpperTo(df, b0);
-const VF b11 = hn::PromoteUpperTo(df, b1);
-const VF b21 = hn::PromoteUpperTo(df, b2);
-const VF b31 = hn::PromoteUpperTo(df, b3);
-
-{
-const VF a00 = hn::Load(df, ar0 + ikc);
-ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-C03);
-}
-if constexpr (kRowsAC > 1) {
-const VF a10 = hn::Load(df, ar1 + ikc);
-ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
-C13);
-}
-
-// C00 is ready again. On SKX, this interleaved unrolling is faster
-// than consuming all `b*1` at the end of the loop.
-{
-const VF a01 = hn::Load(df, ar0 + ikc + NA);
-ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-C03);
-}
-if constexpr (kRowsAC > 1) {
-const VF a11 = hn::Load(df, ar1 + ikc + NA);
-ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-C13);
-}
-
-if constexpr (kRowsAC > 2) {
-const VF a20 = hn::Load(df, ar2 + ikc);
-ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-C23);
-}
-if constexpr (kRowsAC > 3) {
-const VF a30 = hn::Load(df, ar3 + ikc);
-ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
+b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
 C33);
-}
-
-if constexpr (kRowsAC > 2) {
-const VF a21 = hn::Load(df, ar2 + ikc + NA);
-ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-C23);
-}
-if constexpr (kRowsAC > 3) {
-const VF a31 = hn::Load(df, ar3 + ikc + NA);
-ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-C33);
-}
 }
 }
 }
 }
@@ -569,10 +495,6 @@ class MMKernel {
 const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
 const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);

-// Should only get here if `A` is BF16, otherwise `DecompressA` would
-// convert to BF16 and `A_view` points to that.
-HWY_DASSERT(IsBF16<TA>());
-
 {
 const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
 ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -594,7 +516,6 @@ class MMKernel {
 C33);
 }
 } else {  // !HWY_NATIVE_DOT_BF16
-if constexpr (IsBF16<TA>()) {
 // When both are BF16, it is better to load promote odd/even, because
 // lane-crossing promotion for both might be bottlenecked on shuffles.
 VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
@@ -619,79 +540,15 @@ class MMKernel {
 const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
 const VBF a1 =
 kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
-ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
-C13);
+ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+b3e, C00, C01, C02, C03, C10, C11, C12, C13);
 }
 if constexpr (kRowsAC > 2) {
 const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
 const VBF a3 =
 kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
-ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
-C33);
-}
-} else {  // IsF32<TA>(): promote half-B to F32, F32*F32.
-const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
-const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
-const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
-const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
-const VF b00 = hn::PromoteLowerTo(df, b0);
-const VF b10 = hn::PromoteLowerTo(df, b1);
-const VF b20 = hn::PromoteLowerTo(df, b2);
-const VF b30 = hn::PromoteLowerTo(df, b3);
-const VF b01 = hn::PromoteUpperTo(df, b0);
-const VF b11 = hn::PromoteUpperTo(df, b1);
-const VF b21 = hn::PromoteUpperTo(df, b2);
-const VF b31 = hn::PromoteUpperTo(df, b3);
-
-const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA;
-
-{
-const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc);
-ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-C03);
-}
-if constexpr (kRowsAC > 1) {
-const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc);
-ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
-C13);
-}
-
-// C00 is ready again. On SKX, this interleaved unrolling is faster
-// than consuming all `b*1` at the end of the loop.
-{
-const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2);
-ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-C03);
-}
-if constexpr (kRowsAC > 1) {
-const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2);
-ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-C13);
-}
-
-if constexpr (kRowsAC > 2) {
-const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc);
-ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-C23);
-}
-if constexpr (kRowsAC > 3) {
-const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc);
-ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
-C33);
-}
-
-if constexpr (kRowsAC > 2) {
-const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2);
-ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-C23);
-}
-if constexpr (kRowsAC > 3) {
-const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2);
-ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-C33);
-}
+ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+b3e, C20, C21, C22, C23, C30, C31, C32, C33);
 }
 }
 }  // remaining_kc != 0
@@ -699,16 +556,12 @@ class MMKernel {
 // This is a substantial fraction (about 1/3) of the total time, but is
 // called frequently, so do not add a profiler zone.

-if constexpr (hwy::IsSame<Tag, MMSetC>()) {
-MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
-df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-C31, C32, C33, row_ac, col_c, args, C_rows);
-} else {
-static_assert(hwy::IsSame<Tag, MMAddC>());
-MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
-df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-C31, C32, C33, row_ac, col_c, args, C_rows);
-}
+MMStoreHorizontalSumsIntoC<kRowsAC> horz;
+const hn::Full128<float> d4;
+hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
+horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
+C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
+horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
 }
 };

@@ -717,15 +570,6 @@ class MMKernel {
 // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
 // Its member variables avoid long argument lists in Do*().
 class MMPerPackage {
-// Decompression is only required for F32 A and native BF16 dot products.
-// If A is already BF16, we can use a view. Padding is not required
-// because `LoopKC` can handle non-vector multiples. `LoopKC` also contains
-// a special case for F32 `A` and non-native BF16 dot products.
-template <typename TA>
-static constexpr bool WantDecompressA() {
-return HWY_NATIVE_DOT_BF16 && IsF32<TA>();
-}
-
 public:
 MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
 size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
@@ -741,13 +585,6 @@ class MMPerPackage {
 inner_tasks_(config.InnerTasks()),
 line_bytes_(args.env->ctx.allocator.LineBytes()) {}

-// The size of `A` that will actually be used, for purposes of choosing the
-// autotuning candidates. Keep in sync with the `operator()` logic below.
-template <typename TA>
-static constexpr size_t ABytes() {
-return WantDecompressA<TA>() ? sizeof(BF16) : sizeof(TA);
-}
-
 // B and maybe A are decompressed several call layers lower, but not all
 // member functions depend on TA/TB, so pass them as an argument instead of
 // templating the class.
@@ -755,12 +592,16 @@ class MMPerPackage {
 HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
 const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 RowPtrs<TC> C_rows) const {
-if constexpr (WantDecompressA<TA>()) {
+if constexpr (IsBF16<TA>()) {
+// We can use a view, regardless of columns/padding, because `LoopKC`
+// supports non-vector multiples.
+DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
+} else {
+// Always decompress. To reduce code size/compile time, we no longer
+// support a separate F32 kernel; most A are already BF16.
 const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
 DecompressA<MMParallelPolicyT>(A, A_view);
 DispatchOrder(parallel_policy, A_view, B, C_rows);
-} else {
-DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
 }
 }

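The new branch above always converts an F32 `A` to BF16 up front rather than keeping a dedicated F32 kernel. A standalone sketch of that one-time conversion, in plain scalar C++ with hypothetical names (the real DecompressA is vectorized and parallelized, and rounding behavior may differ):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Truncating F32 -> BF16 conversion (keeps sign, exponent and the top 7
// mantissa bits). Production code would typically round-to-nearest-even.
inline uint16_t F32ToBF16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Convert an M x K row-major F32 matrix once, so the single BF16 kernel can
// then serve both F32 and BF16 inputs.
std::vector<uint16_t> DecompressAToBF16(const float* a, size_t m, size_t k) {
  std::vector<uint16_t> out(m * k);
  for (size_t i = 0; i < m * k; ++i) out[i] = F32ToBF16Truncate(a[i]);
  return out;
}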
@@ -937,7 +778,7 @@ class MMPerPackage {
 // Sequential loop over NC/MC/KC, for when the M/N loops are
 // already parallel. This is B3A2C0 in MOMMS terminology: we read
 // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-const auto loop_nc = [&](const StridedViewBF& B_storage_view,
+const auto loop_nc = [&](const StridedViewBF B_storage_view,
 const IndexRange& range_mc,
 const IndexRange& range_kc,
 const IndexRange& range_nc,
@@ -1080,7 +921,7 @@ class MMPerPackage {
 template <typename TB>
 HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
 const IndexRange& range_kc,
-const StridedViewBF& B_view) const {
+const StridedViewBF B_view) const {
 const hn::ScalableTag<BF16> dbf;
 HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);

@@ -1229,7 +1070,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 if (HWY_UNLIKELY(!tuner.HasCandidates())) {
 // Ensure matrix dimensions match each other.
 HWY_ASSERT(K == B.Cols());
-HWY_ASSERT(M <= MMStorage::kMaxM);
+HWY_ASSERT(M <= kMaxBatchSize);
 HWY_ASSERT(K <= MMStorage::kMaxK);
 HWY_ASSERT(N % kNR == 0);
 // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
@@ -1241,9 +1082,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
 }

-tuner.SetCandidates(
-MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
-kMaxMR, kNR, per_key.ranges_np, env.print_config));
+tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
+kNR, per_key.ranges_np, env.print_config));
 }

 const MMConfig& cfg = tuner.NextConfig();
@@ -64,21 +64,19 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 class GenerateCandidates {
 public:
 GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
-size_t sizeof_TA, size_t sizeof_TC, size_t max_mr,
-size_t nr, const IndexRangePartition& ranges_np,
-bool print_config)
+size_t sizeof_TC, size_t max_mr, size_t nr,
+const IndexRangePartition& ranges_np, bool print_config)
 : allocator_(allocator),
 M_(M),
 K_(K),
 N_(N),
-sizeof_TA_(sizeof_TA),
 sizeof_TC_(sizeof_TC),
 max_mr_(max_mr),
 nr_(nr),
 // These influence kc/nc, but are also stored in `MMConfig` for
 // `RangesOf*`. Must be a vector multiple. The previous/next cache line
 // is likely still in L1, but we expect K > 1000 and might as well round
-// up to the line size. Use BF16, not sizeof_TA, because B is BF16.
+// up to the line size. Both A and B are BF16.
 kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
 nc_multiple_(allocator.StepBytes() / sizeof_TC),
 ranges_np_(ranges_np),
@@ -176,8 +174,8 @@ class GenerateCandidates {
 // size. This results in an overestimate, and the loop below will propose
 // the next few smaller values for the autotuner to evaluate.
 const size_t bytes_ab =
-allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream));
-const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16);
+allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream));
+const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
 size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
 kc_max =
 RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
@@ -224,9 +222,9 @@ class GenerateCandidates {
 // packed B. We want `mc * kc` elements of A to fit in L2, alongside
 // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
 // partial.
-const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes();
+const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
 size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
-mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
+mc_max = HWY_MIN(mc_max, kMaxBatchSize);
 HWY_DASSERT(mc_max != 0);
 mc_max = HWY_MIN(mc_max, M_);
 mc_max = hwy::RoundDownTo(mc_max, mr);
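As a worked example of the L2 heuristic above, under assumed cache parameters (1 MiB L2, 64-byte lines, kc = 2048 BF16 columns, 256 KiB of packed B), each resident A row costs kc * sizeof(BF16) plus one cache line, and mc_max comes out near 190 before clamping to kMaxBatchSize and M and rounding down to mr. A small sketch of the arithmetic, with all numbers assumed:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t l2_bytes = 1 << 20, line_bytes = 64;
  const size_t kc = 2048, bytes_b = 256 << 10, sizeof_bf16 = 2;
  // Cost of keeping one row of A resident, plus one line of `partial`.
  const size_t bytes_per_mc = kc * sizeof_bf16 + line_bytes;  // 4160
  const size_t mc_max =
      (l2_bytes - bytes_b + bytes_per_mc - 1) / bytes_per_mc;  // DivCeil -> 190
  printf("bytes_per_mc=%zu mc_max=%zu\n", bytes_per_mc, mc_max);
  return 0;
}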
@@ -340,7 +338,6 @@ class GenerateCandidates {
 const size_t M_;
 const size_t K_;
 const size_t N_;
-const size_t sizeof_TA_;
 const size_t sizeof_TC_;

 const size_t max_mr_;
@@ -358,12 +355,12 @@ class GenerateCandidates {

 // Facade to avoid exposing `GenerateCandidates` in the header.
 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-size_t K, size_t N, size_t sizeof_TA,
-size_t sizeof_TC, size_t max_mr, size_t nr,
+size_t K, size_t N, size_t sizeof_TC,
+size_t max_mr, size_t nr,
 const IndexRangePartition& ranges_np,
 bool print_config) {
-return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr,
-nr, ranges_np, print_config)();
+return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
+ranges_np, print_config)();
 }

 // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
@@ -409,7 +406,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
 char cpu100[100];
 have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

-row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));  // C
+row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
 }

 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
ops/matmul.h
@@ -330,10 +330,8 @@ using StridedViewD = StridedView<double>;
 // Per-package storage for packed A.
 class MMStorage {
 public:
-// Compile-time bounds on matrix dimensions to enable pre-allocating storage
-// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
-// per package and 512 MiB, respectively.
-static constexpr size_t kMaxM = 4096;
+// Compile-time bounds on matrix columns to enable pre-allocating storage
+// and reusing it across `MatMul` calls.
 static constexpr size_t kMaxK = 64 * 1024;
 // Upper bound for per-worker B storage on the stack. Chosen such that one row
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
@@ -348,8 +346,10 @@ class MMStorage {
 MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
 Allocator& allocator = ctx.allocator;

-pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
-"pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
+// 0.5 GiB per package.
+pkg_A_[pkg_idx].reset(
+new MatStorageT<BF16>("pkg_A", Extents2D(kMaxBatchSize, kMaxK),
+allocator, MatPadding::kOdd));

 if (allocator.ShouldBind()) {
 const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
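The "0.5 GiB per package" figure follows from the bounds above: kMaxBatchSize (4096) rows times kMaxK (64 * 1024) BF16 columns at 2 bytes each. A quick sanity check, ignoring any extra padding from MatPadding::kOdd:

#include <cstddef>
#include <cstdint>

// Assumed values from the diff: kMaxBatchSize = 4096, kMaxK = 64 * 1024,
// sizeof(BF16) = 2.
static_assert(std::size_t{4096} * (64 * 1024) * sizeof(std::uint16_t) ==
                  std::size_t{512} * 1024 * 1024,
              "per-package A storage is 0.5 GiB");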
@@ -367,7 +367,7 @@ class MMStorage {
 // faster than on-the-fly when native BF16 is available: it only happens once,
 // not per B tile row, and the cache footprint is smaller.
 StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
-HWY_DASSERT(extents.rows <= kMaxM);
+HWY_DASSERT(extents.rows <= kMaxBatchSize);
 HWY_DASSERT(extents.cols <= kMaxK);
 return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
 extents.cols, pkg_A_[pkg_idx]->Stride());
@@ -527,8 +527,8 @@ static_assert(sizeof(MMConfig) == 32);  // for faster indexing
 #pragma pack(pop)

 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-size_t K, size_t N, size_t sizeof_TA,
-size_t sizeof_TC, size_t max_mr, size_t nr,
+size_t K, size_t N, size_t sizeof_TC,
+size_t max_mr, size_t nr,
 const IndexRangePartition& ranges_np,
 bool print_config);

@@ -120,9 +120,9 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
 const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
 // Dot() uses double-precision summation.
 double tolerance = 20 * norm * eps_f32;
-// If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to
-// BF16, so add extra tolerance.
-if (IsF32<TB>()) {
+// If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
+// F32 to BF16, so add extra tolerance.
+if (IsF32<TA>() || IsF32<TB>()) {
 tolerance += 2 * max_abs * eps_bf16;
 }

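For intuition, the tolerance combines a summation-error term with a BF16 demotion term. A minimal sketch with assumed values for norm and max_abs, and epsilons taken as 2^-23 for F32 and 2^-7 for BF16 (the real test obtains them via hwy::Epsilon):

#include <cstdio>

int main() {
  const double eps_f32 = 1.1920928955078125e-7;  // 2^-23
  const double eps_bf16 = 0.0078125;             // 2^-7
  const double norm = 500.0, max_abs = 3.0;      // assumed example values
  double tolerance = 20 * norm * eps_f32;        // double-precision Dot() term
  const bool either_f32 = true;                  // e.g. TA = F32, TB = BF16
  if (either_f32) tolerance += 2 * max_abs * eps_bf16;  // BF16 demotion term
  printf("tolerance = %g\n", tolerance);         // ~0.048
  return 0;
}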
@@ -33,7 +33,10 @@ namespace gcpp {
 // Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
 // runtime `max_packages` does not exceed this. MatMul's outer per-package loop
 // is disabled if this is 1.
-constexpr size_t kMaxPackages = 1;
+HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;

+// TODO: extend to 16k after updating non_eos.
+HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
+
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
