Simplify MatMul: remove F32 special case (build time)

Also move kMaxM into a separate kMaxBatchSize constant

PiperOrigin-RevId: 802086590
Jan Wassenberg 2025-09-02 04:28:49 -07:00 committed by Copybara-Service
parent 1e3c853e80
commit b7b3d353db
6 changed files with 157 additions and 317 deletions

View File

@@ -248,17 +248,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.max_generated_tokens = max_generated_tokens;
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
-    if (prefill_tbatch_size > MMStorage::kMaxM) {
+    if (prefill_tbatch_size > kMaxBatchSize) {
       HWY_ABORT(
-          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
-          "or increase the constant in MMStorage.\n",
-          prefill_tbatch_size, MMStorage::kMaxM);
+          "prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a "
+          "smaller value, or increase kMaxBatchSize.\n",
+          prefill_tbatch_size, kMaxBatchSize);
     }
-    if (decode_qbatch_size > MMStorage::kMaxM) {
+    if (decode_qbatch_size > kMaxBatchSize) {
       HWY_ABORT(
-          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
-          "or increase the constant in MMStorage.\n",
-          decode_qbatch_size, MMStorage::kMaxM);
+          "decode_qbatch_size %zu > kMaxBatchSize %zu: specify a "
+          "smaller value, or increase kMaxBatchSize.\n",
+          decode_qbatch_size, kMaxBatchSize);
     }
     runtime_config.temperature = temperature;
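Both checks above follow the same shape. As a rough sketch of the guard a caller now needs against the new limit (the free-standing helper and its name are assumed for illustration, not part of this commit):

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Assumed to mirror the constant introduced by this commit.
constexpr size_t kMaxBatchSize = 4096;

// Hypothetical helper: abort when a batch-size knob exceeds the compile-time
// bound that MatMul's pre-allocated per-package storage depends on.
inline void CheckBatchSize(const char* name, size_t value) {
  if (value > kMaxBatchSize) {
    std::fprintf(stderr, "%s %zu > kMaxBatchSize %zu: specify a smaller value.\n",
                 name, value, kMaxBatchSize);
    std::abort();
  }
}

// Usage (hypothetical): CheckBatchSize("prefill_tbatch_size", prefill_tbatch_size);
```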

View File

@@ -96,28 +96,27 @@ struct MMSetC {};
 struct MMAddC {};

 // Stores horizontal sums of up to 16 vectors via transpose.
-template <size_t kRowsAC, bool kAdd>
+template <size_t kRowsAC>
 class MMStoreHorizontalSumsIntoC {
  public:
  static_assert(kNR == 4);  // for `StoreInterleaved4`
-  // Computes horizontal sums of `kRowsAC x kNR` vectors and stores into
-  // `C` starting at `(row_c, col_c)`.
+  // Given 16 (`kRowsAC x kNR`) full vectors of 32-bit float, returns four
+  // 4-wide float vectors with their horizontal sums.
   //
   // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
   // transposed B row vector indexed by `c`. Their elements are thus a subset
   // of the terms of the dot product constituting the final `C[r, c]` result.
   // Thus we compute the horizontal sums of each `Crc`. The elements may be
   // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
   // this does not change their horizontal sum.
-  template <class DF, class VF = hn::Vec<DF>, typename TC>
-  HWY_INLINE void operator()(DF df,                           //
-                             VF C00, VF C01, VF C02, VF C03,  //
-                             VF C10, VF C11, VF C12, VF C13,  //
-                             VF C20, VF C21, VF C22, VF C23,  //
-                             VF C30, VF C31, VF C32, VF C33,  //
-                             const size_t row_c, const size_t col_c,
-                             const MMArgs& args, RowPtrs<TC> C_rows) const {
+  template <class DF, class VF = hn::Vec<DF>, class D4 = hn::Full128<float>,
+            class V4 = hn::Vec<D4>>
+  HWY_INLINE void Reduce4x4(DF df,                           //
+                            VF C00, VF C01, VF C02, VF C03,  //
+                            VF C10, VF C11, VF C12, VF C13,  //
+                            VF C20, VF C21, VF C22, VF C23,  //
+                            VF C30, VF C31, VF C32, VF C33,  //
+                            V4& sum0, V4& sum1, V4& sum2, V4& sum3) {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -133,14 +132,13 @@ class MMStoreHorizontalSumsIntoC {
     // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
     // the elements of one V4. We have four independent rows `r`, hence the
     // code is effectively unrolled, which increases throughput.
-    const hn::CappedTag<float, kNR> d4;
-    using V4 = hn::Vec<decltype(d4)>;
     // Store to four elements per row of `partial`.
     // No loop is required because vectors are at least 4*32 bits.
-    V4 sum0 = MaybeLoad<0>(d4, N, buf);
-    V4 sum1 = MaybeLoad<1>(d4, N, buf);
-    V4 sum2 = MaybeLoad<2>(d4, N, buf);
-    V4 sum3 = MaybeLoad<3>(d4, N, buf);
+    const D4 d4;
+    sum0 = MaybeLoad<0>(d4, N, buf);
+    sum1 = MaybeLoad<1>(d4, N, buf);
+    sum2 = MaybeLoad<2>(d4, N, buf);
+    sum3 = MaybeLoad<3>(d4, N, buf);

     for (size_t lane = 1; lane < N; ++lane) {
       sum0 = MaybeAdd<0>(d4, N, sum0, buf + kNR * lane);
@@ -148,13 +146,23 @@ class MMStoreHorizontalSumsIntoC {
       sum2 = MaybeAdd<2>(d4, N, sum2, buf + kNR * lane);
       sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane);
     }
+  }
+
+  // Scales the dot-product terms and adds bias (if present) and stores the
+  // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
+  // `MMSetC`, the vectors are written as-is (first call, or small K).
+  // Otherwise, they are partial sums and are accumulated into C.
+  template <class D4, class V4 = hn::Vec<D4>, class Tag, typename TC>
+  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
+                        const size_t row_c, const size_t col_c,
+                        const MMArgs& args, RowPtrs<TC> C_rows) const {
     const V4 vscale = hn::Set(d4, args.scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
     const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
   }

  private:
@@ -191,18 +199,20 @@ class MMStoreHorizontalSumsIntoC {
   }

   template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
-            typename TC>
+            class Tag, typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, RowPtrs<TC> C_rows,
+                                            VF4 vadd, Tag, RowPtrs<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
       TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
-      if constexpr (kAdd) {
+      if constexpr (hwy::IsSame<Tag, MMAddC>()) {
         vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
-      }  // else: add bias (only once, the first time we store to C)
+      } else {
+        static_assert(hwy::IsSame<Tag, MMSetC>());
+        // vadd remains the bias (added once, the first time we store to C)
+      }
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
     }
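The `MMSetC`/`MMAddC` tags now carry what the removed `kAdd` template parameter used to express. A minimal scalar sketch of the intended store semantics (plain-float signature and helper name are illustrative only, not the real API):

```cpp
#include <type_traits>

struct MMSetC {};  // first K block (or small K): overwrite C, add bias once
struct MMAddC {};  // later K blocks: accumulate into the existing C value

template <class Tag>
void StoreOrAccumulate(float sum, float scale, float bias, float& c) {
  if constexpr (std::is_same_v<Tag, MMAddC>) {
    c = sum * scale + c;     // partial sum: accumulate into the prior C value
  } else {
    static_assert(std::is_same_v<Tag, MMSetC>);
    c = sum * scale + bias;  // first store: write as-is, bias added exactly once
  }
}
```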
@@ -215,9 +225,9 @@ class MMKernel {
   // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-  template <class Tag, typename TA, typename TC>
-  static HWY_INLINE void A2C0(const StridedView<TA> A_view,
-                              const StridedViewBF& B_view, size_t mr,
+  template <class Tag, typename TC>
+  static HWY_INLINE void A2C0(const StridedViewBF A_view,
+                              const StridedViewBF B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
                               RowPtrs<TC> C_rows) {
@@ -357,34 +367,18 @@ class MMKernel {
     }
   }

-  // For A=F32, B=BF16 without native BF16 dot product: one lane-crossing
-  // promotion is likely cheaper than AND+SHIFT for promoting odd/even BF.
-  // Caller already promoted B, so all inputs are F32.
-  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
-  static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2,
-                                              VF b3, VF& C0, VF& C1, VF& C2,
-                                              VF& C3) {
-    HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
-    C0 = hn::MulAdd(a, b0, C0);
-    C1 = hn::MulAdd(a, b1, C1);
-    C2 = hn::MulAdd(a, b2, C2);
-    C3 = hn::MulAdd(a, b3, C3);
-  }
-
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
   // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
-  // `B` is BF16, `A` and `C` can be F32 or BF16.
-  template <size_t kRowsAC, /*deduced:*/ class Tag, typename TA, typename TC>
-  static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
-                                const StridedViewBF& B_view, size_t row_ac,
+  // `A` and `B` are always BF16, `C` can be F32 or BF16.
+  template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
+  static HWY_INLINE void LoopKC(const StridedViewBF A_view,
+                                const StridedViewBF B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
                                 const MMArgs& args, RowPtrs<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
-    HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag<TA>());
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);

     HWY_DASSERT(kRowsAC <= kMaxMR);
@@ -393,10 +387,10 @@ class MMKernel {
     // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
     static_assert(kNR == 4);
-    const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
-    const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
-    const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
-    const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
+    const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
+    const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
+    const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
+    const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
     const BF16* HWY_RESTRICT br0 = B_view.Row(0);
     const BF16* HWY_RESTRICT br1 = B_view.Row(1);
     const BF16* HWY_RESTRICT br2 = B_view.Row(2);
@@ -416,8 +410,6 @@ class MMKernel {
     C33 = hn::Zero(df);

     size_t ikc = 0;
-    // The loop step is always NBF: for non-native BF16 with TA=F32, this
-    // entails 2x unrolling, which helps a little.
     const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
     if (kc >= kc_step) {
       HWY_UNROLL(1)
@@ -432,10 +424,6 @@ class MMKernel {
         const VBF b2 = hn::LoadU(dbf, br2 + ikc);
         const VBF b3 = hn::LoadU(dbf, br3 + ikc);

-        // Should only get here if `A` is BF16, otherwise `DecompressA` would
-        // convert to BF16 and `A_view` points to that.
-        HWY_DASSERT(IsBF16<TA>());
         {
           const VBF a0 = hn::Load(dbf, ar0 + ikc);
           ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -457,102 +445,40 @@ class MMKernel {
                                     C33);
         }
       } else {  // !HWY_NATIVE_DOT_BF16
-        if constexpr (IsBF16<TA>()) {
-          // When both are BF16, it is better to load promote odd/even,
-          // because lane-crossing promotion for both might be bottlenecked on
-          // shuffles.
-          VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
-          {
-            const VBF b0 = hn::LoadU(dbf, br0 + ikc);
-            const VBF b1 = hn::LoadU(dbf, br1 + ikc);
-            const VBF b2 = hn::LoadU(dbf, br2 + ikc);
-            const VBF b3 = hn::LoadU(dbf, br3 + ikc);
-            b0e = hn::PromoteEvenTo(df, b0);
-            b1e = hn::PromoteEvenTo(df, b1);
-            b2e = hn::PromoteEvenTo(df, b2);
-            b3e = hn::PromoteEvenTo(df, b3);
-            b0o = FastPromoteOddTo(df, b0);
-            b1o = FastPromoteOddTo(df, b1);
-            b2o = FastPromoteOddTo(df, b2);
-            b3o = FastPromoteOddTo(df, b3);
-          }
-          // Two rows at a time so we have 8 separate dependency chains,
-          // sufficient for IPC=2 and 4-cycle latency.
-          {
-            const VBF a0 = hn::Load(dbf, ar0 + ikc);
-            const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
-            ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-                                   b3o, b3e, C00, C01, C02, C03, C10, C11,
-                                   C12, C13);
-          }
-          if constexpr (kRowsAC > 2) {
-            const VBF a2 = hn::Load(dbf, ar2 + ikc);
-            const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
-            ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-                                   b3o, b3e, C20, C21, C22, C23, C30, C31,
-                                   C32, C33);
-          }
-        } else {  // IsF32<TA>(): promote BF to 2xF32, F32*F32.
-          // Full-vector loads are a bit faster on SKX than half + PromoteTo.
-          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
-          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
-          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
-          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
-          const VF b00 = hn::PromoteLowerTo(df, b0);
-          const VF b10 = hn::PromoteLowerTo(df, b1);
-          const VF b20 = hn::PromoteLowerTo(df, b2);
-          const VF b30 = hn::PromoteLowerTo(df, b3);
-          const VF b01 = hn::PromoteUpperTo(df, b0);
-          const VF b11 = hn::PromoteUpperTo(df, b1);
-          const VF b21 = hn::PromoteUpperTo(df, b2);
-          const VF b31 = hn::PromoteUpperTo(df, b3);
-          {
-            const VF a00 = hn::Load(df, ar0 + ikc);
-            ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-                                 C03);
-          }
-          if constexpr (kRowsAC > 1) {
-            const VF a10 = hn::Load(df, ar1 + ikc);
-            ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
-                                 C13);
-          }
-          // C00 is ready again. On SKX, this interleaved unrolling is faster
-          // than consuming all `b*1` at the end of the loop.
-          {
-            const VF a01 = hn::Load(df, ar0 + ikc + NA);
-            ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-                                 C03);
-          }
-          if constexpr (kRowsAC > 1) {
-            const VF a11 = hn::Load(df, ar1 + ikc + NA);
-            ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-                                 C13);
-          }
-          if constexpr (kRowsAC > 2) {
-            const VF a20 = hn::Load(df, ar2 + ikc);
-            ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-                                 C23);
-          }
-          if constexpr (kRowsAC > 3) {
-            const VF a30 = hn::Load(df, ar3 + ikc);
-            ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
-                                 C33);
-          }
-          if constexpr (kRowsAC > 2) {
-            const VF a21 = hn::Load(df, ar2 + ikc + NA);
-            ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-                                 C23);
-          }
-          if constexpr (kRowsAC > 3) {
-            const VF a31 = hn::Load(df, ar3 + ikc + NA);
-            ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-                                 C33);
-          }
-        }
+        // When both are BF16, it is better to load promote odd/even,
+        // because lane-crossing promotion for both might be bottlenecked on
+        // shuffles.
+        VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
+        {
+          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
+          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
+          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
+          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
+          b0e = hn::PromoteEvenTo(df, b0);
+          b1e = hn::PromoteEvenTo(df, b1);
+          b2e = hn::PromoteEvenTo(df, b2);
+          b3e = hn::PromoteEvenTo(df, b3);
+          b0o = FastPromoteOddTo(df, b0);
+          b1o = FastPromoteOddTo(df, b1);
+          b2o = FastPromoteOddTo(df, b2);
+          b3o = FastPromoteOddTo(df, b3);
+        }
+        // Two rows at a time so we have 8 separate dependency chains,
+        // sufficient for IPC=2 and 4-cycle latency.
+        {
+          const VBF a0 = hn::Load(dbf, ar0 + ikc);
+          const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
+          ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
+                                 b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
+                                 C13);
+        }
+        if constexpr (kRowsAC > 2) {
+          const VBF a2 = hn::Load(dbf, ar2 + ikc);
+          const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
+          ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
+                                 b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
+                                 C33);
+        }
       }
     }
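The even/odd promotion path works because a horizontal sum does not care which lane a product lands in. A scalar sketch of the identity being relied on (array stand-in for the vector code; names are assumptions):

```cpp
#include <cstddef>

// For one A row and one B row of length n (both already widened to float),
// summing even-index and odd-index products separately yields the same dot
// product as the in-order loop; only the order of the terms differs.
float DotViaEvenOdd(const float* a, const float* b, size_t n) {
  float even = 0.0f, odd = 0.0f;
  for (size_t i = 0; i < n; i += 2) even += a[i] * b[i];
  for (size_t i = 1; i < n; i += 2) odd += a[i] * b[i];
  return even + odd;  // == sum of a[i] * b[i], up to float rounding
}
```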
@@ -569,10 +495,6 @@ class MMKernel {
       const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
       const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);

-      // Should only get here if `A` is BF16, otherwise `DecompressA` would
-      // convert to BF16 and `A_view` points to that.
-      HWY_DASSERT(IsBF16<TA>());
       {
         const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
         ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -594,104 +516,39 @@ class MMKernel {
                                   C33);
       }
     } else {  // !HWY_NATIVE_DOT_BF16
-      if constexpr (IsBF16<TA>()) {
-        // When both are BF16, it is better to load promote odd/even, because
-        // lane-crossing promotion for both might be bottlenecked on shuffles.
-        VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
-        {
-          const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
-          const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
-          const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
-          const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
-          b0e = hn::PromoteEvenTo(df, b0);
-          b1e = hn::PromoteEvenTo(df, b1);
-          b2e = hn::PromoteEvenTo(df, b2);
-          b3e = hn::PromoteEvenTo(df, b3);
-          b0o = FastPromoteOddTo(df, b0);
-          b1o = FastPromoteOddTo(df, b1);
-          b2o = FastPromoteOddTo(df, b2);
-          b3o = FastPromoteOddTo(df, b3);
-        }
-        // Two rows at a time so we have 8 separate dependency chains,
-        // sufficient for IPC=2 and 4-cycle latency.
-        {
-          const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
-          const VBF a1 =
-              kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
-          ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-                                 b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
-                                 C13);
-        }
-        if constexpr (kRowsAC > 2) {
-          const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
-          const VBF a3 =
-              kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
-          ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-                                 b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
-                                 C33);
-        }
-      } else {  // IsF32<TA>(): promote half-B to F32, F32*F32.
-        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
-        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
-        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
-        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
-        const VF b00 = hn::PromoteLowerTo(df, b0);
-        const VF b10 = hn::PromoteLowerTo(df, b1);
-        const VF b20 = hn::PromoteLowerTo(df, b2);
-        const VF b30 = hn::PromoteLowerTo(df, b3);
-        const VF b01 = hn::PromoteUpperTo(df, b0);
-        const VF b11 = hn::PromoteUpperTo(df, b1);
-        const VF b21 = hn::PromoteUpperTo(df, b2);
-        const VF b31 = hn::PromoteUpperTo(df, b3);
-
-        const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA;
-
-        {
-          const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-                               C03);
-        }
-        if constexpr (kRowsAC > 1) {
-          const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
-                               C13);
-        }
-
-        // C00 is ready again. On SKX, this interleaved unrolling is faster
-        // than consuming all `b*1` at the end of the loop.
-        {
-          const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-                               C03);
-        }
-        if constexpr (kRowsAC > 1) {
-          const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-                               C13);
-        }
-        if constexpr (kRowsAC > 2) {
-          const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-                               C23);
-        }
-        if constexpr (kRowsAC > 3) {
-          const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
-                               C33);
-        }
-        if constexpr (kRowsAC > 2) {
-          const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-                               C23);
-        }
-        if constexpr (kRowsAC > 3) {
-          const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-                               C33);
-        }
-      }
+      // When both are BF16, it is better to load promote odd/even, because
+      // lane-crossing promotion for both might be bottlenecked on shuffles.
+      VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
+      {
+        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
+        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
+        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
+        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
+        b0e = hn::PromoteEvenTo(df, b0);
+        b1e = hn::PromoteEvenTo(df, b1);
+        b2e = hn::PromoteEvenTo(df, b2);
+        b3e = hn::PromoteEvenTo(df, b3);
+        b0o = FastPromoteOddTo(df, b0);
+        b1o = FastPromoteOddTo(df, b1);
+        b2o = FastPromoteOddTo(df, b2);
+        b3o = FastPromoteOddTo(df, b3);
+      }
+      // Two rows at a time so we have 8 separate dependency chains,
+      // sufficient for IPC=2 and 4-cycle latency.
+      {
+        const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
+        const VBF a1 =
+            kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
+        ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+                               b3e, C00, C01, C02, C03, C10, C11, C12, C13);
+      }
+      if constexpr (kRowsAC > 2) {
+        const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
+        const VBF a3 =
+            kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
+        ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+                               b3e, C20, C21, C22, C23, C30, C31, C32, C33);
+      }
     }
   }  // remaining_kc != 0
@@ -699,16 +556,12 @@ class MMKernel {
     // This is a substantial fraction (about 1/3) of the total time, but is
     // called frequently, so do not add a profiler zone.
-    if constexpr (hwy::IsSame<Tag, MMSetC>()) {
-      MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
-          df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C_rows);
-    } else {
-      static_assert(hwy::IsSame<Tag, MMAddC>());
-      MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
-          df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C_rows);
-    }
+    MMStoreHorizontalSumsIntoC<kRowsAC> horz;
+    const hn::Full128<float> d4;
+    hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
+    horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
+                   C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
+    horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
   }
 };
@@ -717,15 +570,6 @@ class MMKernel {
 // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
 // Its member variables avoid long argument lists in Do*().
 class MMPerPackage {
-  // Decompression is only required for F32 A and native BF16 dot products.
-  // If A is already BF16, we can use a view. Padding is not required
-  // because `LoopKC` can handle non-vector multiples. `LoopKC` also contains
-  // a special case for F32 `A` and non-native BF16 dot products.
-  template <typename TA>
-  static constexpr bool WantDecompressA() {
-    return HWY_NATIVE_DOT_BF16 && IsF32<TA>();
-  }
-
  public:
  MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
@@ -741,13 +585,6 @@ class MMPerPackage {
         inner_tasks_(config.InnerTasks()),
         line_bytes_(args.env->ctx.allocator.LineBytes()) {}

-  // The size of `A` that will actually be used, for purposes of choosing the
-  // autotuning candidates. Keep in sync with the `operator()` logic below.
-  template <typename TA>
-  static constexpr size_t ABytes() {
-    return WantDecompressA<TA>() ? sizeof(BF16) : sizeof(TA);
-  }
-
   // B and maybe A are decompressed several call layers lower, but not all
   // member functions depend on TA/TB, so pass them as an argument instead of
   // templating the class.
@@ -755,12 +592,16 @@ class MMPerPackage {
   HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
                                const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                                RowPtrs<TC> C_rows) const {
-    if constexpr (WantDecompressA<TA>()) {
+    if constexpr (IsBF16<TA>()) {
+      // We can use a view, regardless of columns/padding, because `LoopKC`
+      // supports non-vector multiples.
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
+    } else {
+      // Always decompress. To reduce code size/compile time, we no longer
+      // support a separate F32 kernel; most A are already BF16.
       const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
       DecompressA<MMParallelPolicyT>(A, A_view);
       DispatchOrder(parallel_policy, A_view, B, C_rows);
-    } else {
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
     }
   }
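With the F32 kernel gone, an F32 `A` is converted to BF16 exactly once per `MatMul` call by `DecompressA`, rather than being handled by a separate kernel. A scalar sketch of that one-time conversion (the real code is vectorized with Highway; this helper is purely illustrative and skips NaN handling):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Round-to-nearest-even truncation of an IEEE-754 float to BF16 bits.
inline uint16_t F32ToBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);  // round to nearest, ties to even
  return static_cast<uint16_t>(bits >> 16);
}

// Convert one row of A; done once per MatMul instead of once per B tile row.
inline void ConvertRowToBF16(const float* in, uint16_t* out, size_t cols) {
  for (size_t i = 0; i < cols; ++i) out[i] = F32ToBF16Bits(in[i]);
}
```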
@@ -937,7 +778,7 @@ class MMPerPackage {
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
     // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-    const auto loop_nc = [&](const StridedViewBF& B_storage_view,
+    const auto loop_nc = [&](const StridedViewBF B_storage_view,
                              const IndexRange& range_mc,
                              const IndexRange& range_kc,
                              const IndexRange& range_nc,
@@ -1080,7 +921,7 @@ class MMPerPackage {
   template <typename TB>
   HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
                                        const IndexRange& range_kc,
-                                       const StridedViewBF& B_view) const {
+                                       const StridedViewBF B_view) const {
     const hn::ScalableTag<BF16> dbf;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
@@ -1229,7 +1070,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
     // Ensure matrix dimensions match each other.
     HWY_ASSERT(K == B.Cols());
-    HWY_ASSERT(M <= MMStorage::kMaxM);
+    HWY_ASSERT(M <= kMaxBatchSize);
     HWY_ASSERT(K <= MMStorage::kMaxK);
     HWY_ASSERT(N % kNR == 0);
     // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
@@ -1241,9 +1082,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
       HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
     }
-    tuner.SetCandidates(
-        MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
-                     kMaxMR, kNR, per_key.ranges_np, env.print_config));
+    tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
+                                     kNR, per_key.ranges_np, env.print_config));
   }
   const MMConfig& cfg = tuner.NextConfig();

View File

@@ -64,21 +64,19 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 class GenerateCandidates {
  public:
  GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
-                     size_t sizeof_TA, size_t sizeof_TC, size_t max_mr,
-                     size_t nr, const IndexRangePartition& ranges_np,
-                     bool print_config)
+                     size_t sizeof_TC, size_t max_mr, size_t nr,
+                     const IndexRangePartition& ranges_np, bool print_config)
       : allocator_(allocator),
         M_(M),
         K_(K),
         N_(N),
-        sizeof_TA_(sizeof_TA),
         sizeof_TC_(sizeof_TC),
         max_mr_(max_mr),
         nr_(nr),
         // These influence kc/nc, but are also stored in `MMConfig` for
         // `RangesOf*`. Must be a vector multiple. The previous/next cache line
         // is likely still in L1, but we expect K > 1000 and might as well round
-        // up to the line size. Use BF16, not sizeof_TA, because B is BF16.
+        // up to the line size. Both A and B are BF16.
         kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
         nc_multiple_(allocator.StepBytes() / sizeof_TC),
         ranges_np_(ranges_np),
@@ -176,8 +174,8 @@ class GenerateCandidates {
     // size. This results in an overestimate, and the loop below will propose
     // the next few smaller values for the autotuner to evaluate.
     const size_t bytes_ab =
-        allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream));
-    const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16);
+        allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream));
+    const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
     size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
     kc_max =
         RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
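To make the L1 sizing concrete, here is the same arithmetic with assumed values (32 KiB L1, `rows_a` = 4, `nr_` = 4, 1-byte `SfpStream`; none of these numbers appear in the diff):

```cpp
#include <cstddef>

// Assumed constants; the real values come from Allocator and the caller.
constexpr size_t kL1Bytes = 32 * 1024;
constexpr size_t kRowsA = 4, kNR = 4;
constexpr size_t kSizeofBF16 = 2, kSizeofSfp = 1;

constexpr size_t kBytesAB = kL1Bytes * (kSizeofBF16 + kSizeofSfp);      // 98304
constexpr size_t kColBytes = kRowsA * kSizeofBF16 + kNR * kSizeofBF16;  // 16
// DivCeil(98304, 16) = 6144; the real code then clamps to MMStorage::kMaxKC
// and rounds down to kc_multiple_.
constexpr size_t kKcMaxEstimate = (kBytesAB + kColBytes - 1) / kColBytes;
static_assert(kKcMaxEstimate == 6144, "worked example only");
```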
@@ -224,9 +222,9 @@ class GenerateCandidates {
     // packed B. We want `mc * kc` elements of A to fit in L2, alongside
     // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
     // partial.
-    const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes();
+    const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
     size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
-    mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
+    mc_max = HWY_MIN(mc_max, kMaxBatchSize);
     HWY_DASSERT(mc_max != 0);
     mc_max = HWY_MIN(mc_max, M_);
     mc_max = hwy::RoundDownTo(mc_max, mr);
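And the corresponding L2 bound on `mc`, again with assumed numbers (1 MiB L2, 256 KiB `bytes_b`, `kc` = 2048, 64-byte lines, `mr` = 4):

```cpp
#include <cstddef>

// Assumed constants for a worked example; not taken from the diff.
constexpr size_t kL2Bytes = 1024 * 1024, kBytesB = 256 * 1024;
constexpr size_t kKc = 2048, kLineBytes = 64, kMr = 4;

constexpr size_t kBytesPerMc = kKc * 2 /*sizeof(BF16)*/ + kLineBytes;  // 4160
// DivCeil(1 MiB - 256 KiB, 4160) = 190 rows of A fit in L2.
constexpr size_t kMcMax = (kL2Bytes - kBytesB + kBytesPerMc - 1) / kBytesPerMc;
// The real code further clamps to kMaxBatchSize and M, then rounds down to mr.
static_assert(kMcMax == 190 && (kMcMax / kMr) * kMr == 188, "worked example only");
```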
@@ -340,7 +338,6 @@ class GenerateCandidates {
   const size_t M_;
   const size_t K_;
   const size_t N_;
-  const size_t sizeof_TA_;
   const size_t sizeof_TC_;
   const size_t max_mr_;
@@ -358,12 +355,12 @@
 // Facade to avoid exposing `GenerateCandidates` in the header.
 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-                                   size_t K, size_t N, size_t sizeof_TA,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config) {
-  return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr,
-                            nr, ranges_np, print_config)();
+  return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
+                            ranges_np, print_config)();
 }

 // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
@@ -409,7 +406,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

-  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(MMStorage::kMaxM));  // C
+  row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize));  // C
 }

 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {

View File

@@ -330,10 +330,8 @@ using StridedViewD = StridedView<double>;
 // Per-package storage for packed A.
 class MMStorage {
  public:
-  // Compile-time bounds on matrix dimensions to enable pre-allocating storage
-  // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
-  // per package and 512 MiB, respectively.
-  static constexpr size_t kMaxM = 4096;
+  // Compile-time bounds on matrix columns to enable pre-allocating storage
+  // and reusing it across `MatMul` calls.
   static constexpr size_t kMaxK = 64 * 1024;
   // Upper bound for per-worker B storage on the stack. Chosen such that one row
   // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
@@ -348,8 +346,10 @@ class MMStorage {
     MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
       Allocator& allocator = ctx.allocator;
-      pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
-          "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
+      // 0.5 GiB per package.
+      pkg_A_[pkg_idx].reset(
+          new MatStorageT<BF16>("pkg_A", Extents2D(kMaxBatchSize, kMaxK),
+                                allocator, MatPadding::kOdd));

       if (allocator.ShouldBind()) {
         const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
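The "0.5 GiB per package" figure follows from the new constants, assuming `kMaxBatchSize` = 4096, `kMaxK` = 64 * 1024 and 2-byte BF16, and ignoring row padding:

```cpp
#include <cstddef>

static_assert(size_t{4096} * (64 * 1024) * 2 == size_t{512} << 20,
              "pkg_A is 512 MiB = 0.5 GiB per package before padding");
```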
@@ -367,7 +367,7 @@ class MMStorage {
   // faster than on-the-fly when native BF16 is available: it only happens once,
   // not per B tile row, and the cache footprint is smaller.
   StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
-    HWY_DASSERT(extents.rows <= kMaxM);
+    HWY_DASSERT(extents.rows <= kMaxBatchSize);
     HWY_DASSERT(extents.cols <= kMaxK);
     return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
                          extents.cols, pkg_A_[pkg_idx]->Stride());
@@ -527,8 +527,8 @@ static_assert(sizeof(MMConfig) == 32);  // for faster indexing
 #pragma pack(pop)

 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-                                   size_t K, size_t N, size_t sizeof_TA,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config);

View File

@@ -120,9 +120,9 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
   // Dot() uses double-precision summation.
   double tolerance = 20 * norm * eps_f32;
-  // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to
-  // BF16, so add extra tolerance.
-  if (IsF32<TB>()) {
+  // If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
+  // F32 to BF16, so add extra tolerance.
+  if (IsF32<TA>() || IsF32<TB>()) {
     tolerance += 2 * max_abs * eps_bf16;
   }
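Restating the widened tolerance rule as a standalone function, under the assumption that `eps_bf16` is the BF16 machine epsilon (2^-8) and `eps_f32` the F32 epsilon (2^-23); the real test derives `norm` and `max_abs` from the inputs:

```cpp
// Hypothetical standalone version of the test tolerance.
double Tolerance(double norm, double max_abs, bool a_or_b_is_f32) {
  const double eps_f32 = 1.1920929e-7;  // 2^-23
  const double eps_bf16 = 3.90625e-3;   // 2^-8
  double tolerance = 20 * norm * eps_f32;
  // MatMul demotes F32 inputs to BF16 but Dot() does not; allow for that error.
  if (a_or_b_is_f32) tolerance += 2 * max_abs * eps_bf16;
  return tolerance;
}
```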

View File

@@ -33,7 +33,10 @@ namespace gcpp {
 // Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
 // runtime `max_packages` does not exceed this. MatMul's outer per-package loop
 // is disabled if this is 1.
-constexpr size_t kMaxPackages = 1;
+HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;
+
+// TODO: extend to 16k after updating non_eos.
+HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;

 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };