diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index 16c9595..2a49349 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -248,17 +248,17 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     runtime_config.max_generated_tokens = max_generated_tokens;
     runtime_config.prefill_tbatch_size = prefill_tbatch_size;
     runtime_config.decode_qbatch_size = decode_qbatch_size;
-    if (prefill_tbatch_size > MMStorage::kMaxM) {
+    if (prefill_tbatch_size > kMaxBatchSize) {
       HWY_ABORT(
-          "prefill_tbatch_size %zu > kMaxM %zu: specify a smaller value, "
-          "or increase the constant in MMStorage.\n",
-          prefill_tbatch_size, MMStorage::kMaxM);
+          "prefill_tbatch_size %zu > kMaxBatchSize %zu: specify a "
+          "smaller value, or increase kMaxBatchSize.\n",
+          prefill_tbatch_size, kMaxBatchSize);
     }
-    if (decode_qbatch_size > MMStorage::kMaxM) {
+    if (decode_qbatch_size > kMaxBatchSize) {
       HWY_ABORT(
-          "decode_qbatch_size %zu > kMaxM %zu: specify a smaller value, "
-          "or increase the constant in MMStorage.\n",
-          decode_qbatch_size, MMStorage::kMaxM);
+          "decode_qbatch_size %zu > kMaxBatchSize %zu: specify a "
+          "smaller value, or increase kMaxBatchSize.\n",
+          decode_qbatch_size, kMaxBatchSize);
     }
 
     runtime_config.temperature = temperature;
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 53dfb05..a9685e2 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -96,28 +96,27 @@ struct MMSetC {};
 struct MMAddC {};
 
 // Stores horizontal sums of up to 16 vectors via transpose.
-template <size_t kRowsAC, size_t kNR, bool kAdd>
+template <size_t kRowsAC, size_t kNR>
 class MMStoreHorizontalSumsIntoC {
  public:
  static_assert(kNR == 4);  // for `StoreInterleaved4`
 
-  // Computes horizontal sums of `kRowsAC x kNR` vectors and stores into
-  // `C` starting at `(row_c, col_c)`.
-  //
+  // Given 16 (`kRowsAC x kNR`) full vectors of 32-bit float, returns four
+  // 4-wide float vectors with their horizontal sums.
   // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
   // transposed B row vector indexed by `c`. Their elements are thus a subset
   // of the terms of the dot product constituting the final `C[r, c]` result.
   // Thus we compute the horizontal sums of each `Crc`. The elements may be
   // permuted because we multiply bf16 via `ReorderWidenMulAccumulate`, but
   // this does not change their horizontal sum.
-  template <class DF, class VF = hn::Vec<DF>, typename TC>
-  HWY_INLINE void operator()(DF df,                           //
-                             VF C00, VF C01, VF C02, VF C03,  //
-                             VF C10, VF C11, VF C12, VF C13,  //
-                             VF C20, VF C21, VF C22, VF C23,  //
-                             VF C30, VF C31, VF C32, VF C33,  //
-                             const size_t row_c, const size_t col_c,
-                             const MMArgs& args, RowPtrs C_rows) const {
+  template <class DF, class VF = hn::Vec<DF>, class D4 = hn::Full128<float>,
+            class V4 = hn::Vec<D4>>
+  HWY_INLINE void Reduce4x4(DF df,                           //
+                            VF C00, VF C01, VF C02, VF C03,  //
+                            VF C10, VF C11, VF C12, VF C13,  //
+                            VF C20, VF C21, VF C22, VF C23,  //
+                            VF C30, VF C31, VF C32, VF C33,  //
+                            V4& sum0, V4& sum1, V4& sum2, V4& sum3) {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     HWY_LANES_CONSTEXPR const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -133,14 +132,13 @@ class MMStoreHorizontalSumsIntoC {
     // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
     // the elements of one V4. We have four independent rows `r`, hence the
     // code is effectively unrolled, which increases throughput.
-    const hn::CappedTag<float, 4> d4;
-    using V4 = hn::Vec<decltype(d4)>;
     // Store to four elements per row of `partial`.
     // No loop is required because vectors are at least 4*32 bits.
-    V4 sum0 = MaybeLoad<0>(d4, N, buf);
-    V4 sum1 = MaybeLoad<1>(d4, N, buf);
-    V4 sum2 = MaybeLoad<2>(d4, N, buf);
-    V4 sum3 = MaybeLoad<3>(d4, N, buf);
+    const D4 d4;
+    sum0 = MaybeLoad<0>(d4, N, buf);
+    sum1 = MaybeLoad<1>(d4, N, buf);
+    sum2 = MaybeLoad<2>(d4, N, buf);
+    sum3 = MaybeLoad<3>(d4, N, buf);
 
     for (size_t lane = 1; lane < N; ++lane) {
       sum0 = MaybeAdd<0>(d4, N, sum0, buf + kNR * lane);
@@ -148,13 +146,23 @@ class MMStoreHorizontalSumsIntoC {
       sum2 = MaybeAdd<2>(d4, N, sum2, buf + kNR * lane);
       sum3 = MaybeAdd<3>(d4, N, sum3, buf + kNR * lane);
     }
+  }
+
+  // Scales the dot-product terms and adds bias (if present) and stores the
+  // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
+  // `MMSetC`, the vectors are written as-is (first call, or small K).
+  // Otherwise, they are partial sums and are accumulated into C.
+  template <class D4, class V4 = hn::Vec<D4>, class Tag, typename TC>
+  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
+                        const size_t row_c, const size_t col_c,
+                        const MMArgs& args, RowPtrs C_rows) const {
     const V4 vscale = hn::Set(d4, args.scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
     const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
   }
 
  private:
@@ -191,18 +199,20 @@ class MMStoreHorizontalSumsIntoC {
   }
 
   template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>,
-            typename TC>
+            class Tag, typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, RowPtrs C_rows,
+                                            VF4 vadd, Tag, RowPtrs C_rows,
                                             const size_t row_c, const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
       TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
-      if constexpr (kAdd) {
+      if constexpr (hwy::IsSame<Tag, MMAddC>()) {
        vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
-      }  // else: add bias (only once, the first time we store to C)
-
+      } else {
+        static_assert(hwy::IsSame<Tag, MMSetC>());
+        // vadd remains the bias (added once, the first time we store to C)
+      }
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
     }
   }
@@ -215,9 +225,9 @@ class MMKernel {
   // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-  template <typename TA, class Tag, typename TC>
-  static HWY_INLINE void A2C0(const StridedView<TA> A_view,
-                              const StridedViewBF& B_view, size_t mr,
+  template <class Tag, typename TC>
+  static HWY_INLINE void A2C0(const StridedViewBF A_view,
+                              const StridedViewBF B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
                               RowPtrs C_rows) {
@@ -357,34 +367,18 @@ class MMKernel {
     }
   }
 
-  // For A=F32, B=BF16 without native BF16 dot product: one lane-crossing
-  // promotion is likely cheaper than AND+SHIFT for promoting odd/even BF.
-  // Caller already promoted B, so all inputs are F32.
-  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
-  static HWY_INLINE void ElementwiseMulAccF32(DF df, VF a, VF b0, VF b1, VF b2,
-                                              VF b3, VF& C0, VF& C1, VF& C2,
-                                              VF& C3) {
-    HWY_DASSERT(!HWY_NATIVE_DOT_BF16);
-    C0 = hn::MulAdd(a, b0, C0);
-    C1 = hn::MulAdd(a, b1, C1);
-    C2 = hn::MulAdd(a, b2, C2);
-    C3 = hn::MulAdd(a, b3, C3);
-  }
-
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
   // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
-  // `B` is BF16, `A` and `C` can be F32 or BF16.
-  template <size_t kRowsAC, typename TA, class Tag, typename TC>
-  static HWY_INLINE void LoopKC(const StridedView<TA> A_view,
-                                const StridedViewBF& B_view, size_t row_ac,
+  // `A` and `B` are always BF16, `C` can be F32 or BF16.
+  template <size_t kRowsAC, class Tag, typename TC>
+  static HWY_INLINE void LoopKC(const StridedViewBF A_view,
+                                const StridedViewBF B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
                                 const MMArgs& args, RowPtrs C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
-
-    HWY_LANES_CONSTEXPR const size_t NA = hn::Lanes(hn::ScalableTag<float>());
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
     HWY_DASSERT(kRowsAC <= kMaxMR);
@@ -393,10 +387,10 @@ class MMKernel {
 
     // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
     static_assert(kNR == 4);
-    const TA* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
-    const TA* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
-    const TA* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
-    const TA* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
+    const BF16* HWY_RESTRICT ar0 = A_view.Row(imc + 0);
+    const BF16* HWY_RESTRICT ar1 = kRowsAC > 1 ? A_view.Row(imc + 1) : nullptr;
+    const BF16* HWY_RESTRICT ar2 = kRowsAC > 2 ? A_view.Row(imc + 2) : nullptr;
+    const BF16* HWY_RESTRICT ar3 = kRowsAC > 3 ? A_view.Row(imc + 3) : nullptr;
     const BF16* HWY_RESTRICT br0 = B_view.Row(0);
     const BF16* HWY_RESTRICT br1 = B_view.Row(1);
     const BF16* HWY_RESTRICT br2 = B_view.Row(2);
@@ -416,8 +410,6 @@ class MMKernel {
        C33 = hn::Zero(df);
 
    size_t ikc = 0;
-    // The loop step is always NBF: for non-native BF16 with TA=F32, this
-    // entails 2x unrolling, which helps a little.
    const HWY_LANES_CONSTEXPR size_t kc_step = NBF;
    if (kc >= kc_step) {
      HWY_UNROLL(1)
@@ -432,10 +424,6 @@ class MMKernel {
        const VBF b2 = hn::LoadU(dbf, br2 + ikc);
        const VBF b3 = hn::LoadU(dbf, br3 + ikc);
 
-        // Should only get here if `A` is BF16, otherwise `DecompressA` would
-        // convert to BF16 and `A_view` points to that.
-        HWY_DASSERT(IsBF16<TA>());
-
        {
          const VBF a0 = hn::Load(dbf, ar0 + ikc);
          ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -457,102 +445,40 @@ class MMKernel {
                                    C33);
        }
      } else {  // !HWY_NATIVE_DOT_BF16
-        if constexpr (IsBF16<TA>()) {
-          // When both are BF16, it is better to load promote odd/even,
-          // because lane-crossing promotion for both might be bottlenecked on
-          // shuffles.
-          VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
-          {
-            const VBF b0 = hn::LoadU(dbf, br0 + ikc);
-            const VBF b1 = hn::LoadU(dbf, br1 + ikc);
-            const VBF b2 = hn::LoadU(dbf, br2 + ikc);
-            const VBF b3 = hn::LoadU(dbf, br3 + ikc);
-            b0e = hn::PromoteEvenTo(df, b0);
-            b1e = hn::PromoteEvenTo(df, b1);
-            b2e = hn::PromoteEvenTo(df, b2);
-            b3e = hn::PromoteEvenTo(df, b3);
-            b0o = FastPromoteOddTo(df, b0);
-            b1o = FastPromoteOddTo(df, b1);
-            b2o = FastPromoteOddTo(df, b2);
-            b3o = FastPromoteOddTo(df, b3);
-          }
-
-          // Two rows at a time so we have 8 separate dependency chains,
-          // sufficient for IPC=2 and 4-cycle latency.
-          {
-            const VBF a0 = hn::Load(dbf, ar0 + ikc);
-            const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
-            ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-                                   b3o, b3e, C00, C01, C02, C03, C10, C11,
-                                   C12, C13);
-          }
-          if constexpr (kRowsAC > 2) {
-            const VBF a2 = hn::Load(dbf, ar2 + ikc);
-            const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
-            ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-                                   b3o, b3e, C20, C21, C22, C23, C30, C31,
-                                   C32, C33);
-          }
-        } else {  // IsF32(): promote BF to 2xF32, F32*F32.
-          // Full-vector loads are a bit faster on SKX than half + PromoteTo.
+        // When both are BF16, it is better to load promote odd/even,
+        // because lane-crossing promotion for both might be bottlenecked on
+        // shuffles.
+        VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
+        {
          const VBF b0 = hn::LoadU(dbf, br0 + ikc);
          const VBF b1 = hn::LoadU(dbf, br1 + ikc);
          const VBF b2 = hn::LoadU(dbf, br2 + ikc);
          const VBF b3 = hn::LoadU(dbf, br3 + ikc);
-          const VF b00 = hn::PromoteLowerTo(df, b0);
-          const VF b10 = hn::PromoteLowerTo(df, b1);
-          const VF b20 = hn::PromoteLowerTo(df, b2);
-          const VF b30 = hn::PromoteLowerTo(df, b3);
-          const VF b01 = hn::PromoteUpperTo(df, b0);
-          const VF b11 = hn::PromoteUpperTo(df, b1);
-          const VF b21 = hn::PromoteUpperTo(df, b2);
-          const VF b31 = hn::PromoteUpperTo(df, b3);
+          b0e = hn::PromoteEvenTo(df, b0);
+          b1e = hn::PromoteEvenTo(df, b1);
+          b2e = hn::PromoteEvenTo(df, b2);
+          b3e = hn::PromoteEvenTo(df, b3);
+          b0o = FastPromoteOddTo(df, b0);
+          b1o = FastPromoteOddTo(df, b1);
+          b2o = FastPromoteOddTo(df, b2);
+          b3o = FastPromoteOddTo(df, b3);
+        }
 
-          {
-            const VF a00 = hn::Load(df, ar0 + ikc);
-            ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-                                 C03);
-          }
-          if constexpr (kRowsAC > 1) {
-            const VF a10 = hn::Load(df, ar1 + ikc);
-            ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
+        // Two rows at a time so we have 8 separate dependency chains,
+        // sufficient for IPC=2 and 4-cycle latency.
+        {
+          const VBF a0 = hn::Load(dbf, ar0 + ikc);
+          const VBF a1 = kRowsAC > 1 ? hn::Load(dbf, ar1 + ikc) : a0;
+          ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
+                                 b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
                                  C13);
-          }
-
-          // C00 is ready again. On SKX, this interleaved unrolling is faster
-          // than consuming all `b*1` at the end of the loop.
-          {
-            const VF a01 = hn::Load(df, ar0 + ikc + NA);
-            ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-                                 C03);
-          }
-          if constexpr (kRowsAC > 1) {
-            const VF a11 = hn::Load(df, ar1 + ikc + NA);
-            ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-                                 C13);
-          }
-
-          if constexpr (kRowsAC > 2) {
-            const VF a20 = hn::Load(df, ar2 + ikc);
-            ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-                                 C23);
-          }
-          if constexpr (kRowsAC > 3) {
-            const VF a30 = hn::Load(df, ar3 + ikc);
-            ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
+        }
+        if constexpr (kRowsAC > 2) {
+          const VBF a2 = hn::Load(dbf, ar2 + ikc);
+          const VBF a3 = kRowsAC > 3 ? hn::Load(dbf, ar3 + ikc) : a2;
+          ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
+                                 b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
                                  C33);
-          }
-
-          if constexpr (kRowsAC > 2) {
-            const VF a21 = hn::Load(df, ar2 + ikc + NA);
-            ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-                                 C23);
-          }
-          if constexpr (kRowsAC > 3) {
-            const VF a31 = hn::Load(df, ar3 + ikc + NA);
-            ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-                                 C33);
-          }
        }
      }
    }
@@ -569,10 +495,6 @@ class MMKernel {
      const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
      const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
 
-      // Should only get here if `A` is BF16, otherwise `DecompressA` would
-      // convert to BF16 and `A_view` points to that.
-      HWY_DASSERT(IsBF16<TA>());
-
      {
        const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
        ElementwiseMulAccNativeBF(dbf, a0, b0, b1, b2, b3, C00, C01, C02,
@@ -594,104 +516,39 @@ class MMKernel {
                                  C33);
      }
    } else {  // !HWY_NATIVE_DOT_BF16
-      if constexpr (IsBF16<TA>()) {
-        // When both are BF16, it is better to load promote odd/even, because
-        // lane-crossing promotion for both might be bottlenecked on shuffles.
-        VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
-        {
-          const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
-          const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
-          const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
-          const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
-          b0e = hn::PromoteEvenTo(df, b0);
-          b1e = hn::PromoteEvenTo(df, b1);
-          b2e = hn::PromoteEvenTo(df, b2);
-          b3e = hn::PromoteEvenTo(df, b3);
-          b0o = FastPromoteOddTo(df, b0);
-          b1o = FastPromoteOddTo(df, b1);
-          b2o = FastPromoteOddTo(df, b2);
-          b3o = FastPromoteOddTo(df, b3);
-        }
-
-        // Two rows at a time so we have 8 separate dependency chains,
-        // sufficient for IPC=2 and 4-cycle latency.
-        {
-          const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
-          const VBF a1 =
-              kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
-          ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e,
-                                 b3o, b3e, C00, C01, C02, C03, C10, C11, C12,
-                                 C13);
-        }
-        if constexpr (kRowsAC > 2) {
-          const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
-          const VBF a3 =
-              kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
-          ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e,
-                                 b3o, b3e, C20, C21, C22, C23, C30, C31, C32,
-                                 C33);
-        }
-      } else {  // IsF32(): promote half-B to F32, F32*F32.
+      // When both are BF16, it is better to load promote odd/even, because
+      // lane-crossing promotion for both might be bottlenecked on shuffles.
+      VF b0e, b1e, b2e, b3e, b0o, b1o, b2o, b3o;
+      {
        const VBF b0 = hn::LoadN(dbf, br0 + ikc, remaining_kc);
        const VBF b1 = hn::LoadN(dbf, br1 + ikc, remaining_kc);
        const VBF b2 = hn::LoadN(dbf, br2 + ikc, remaining_kc);
        const VBF b3 = hn::LoadN(dbf, br3 + ikc, remaining_kc);
-        const VF b00 = hn::PromoteLowerTo(df, b0);
-        const VF b10 = hn::PromoteLowerTo(df, b1);
-        const VF b20 = hn::PromoteLowerTo(df, b2);
-        const VF b30 = hn::PromoteLowerTo(df, b3);
-        const VF b01 = hn::PromoteUpperTo(df, b0);
-        const VF b11 = hn::PromoteUpperTo(df, b1);
-        const VF b21 = hn::PromoteUpperTo(df, b2);
-        const VF b31 = hn::PromoteUpperTo(df, b3);
+        b0e = hn::PromoteEvenTo(df, b0);
+        b1e = hn::PromoteEvenTo(df, b1);
+        b2e = hn::PromoteEvenTo(df, b2);
+        b3e = hn::PromoteEvenTo(df, b3);
+        b0o = FastPromoteOddTo(df, b0);
+        b1o = FastPromoteOddTo(df, b1);
+        b2o = FastPromoteOddTo(df, b2);
+        b3o = FastPromoteOddTo(df, b3);
+      }
 
-        const size_t remaining2 = remaining_kc <= NA ? 0 : remaining_kc - NA;
-
-        {
-          const VF a00 = hn::LoadN(df, ar0 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a00, b00, b10, b20, b30, C00, C01, C02,
-                               C03);
-        }
-        if constexpr (kRowsAC > 1) {
-          const VF a10 = hn::LoadN(df, ar1 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a10, b00, b10, b20, b30, C10, C11, C12,
-                               C13);
-        }
-
-        // C00 is ready again. On SKX, this interleaved unrolling is faster
-        // than consuming all `b*1` at the end of the loop.
-        {
-          const VF a01 = hn::LoadN(df, ar0 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a01, b01, b11, b21, b31, C00, C01, C02,
-                               C03);
-        }
-        if constexpr (kRowsAC > 1) {
-          const VF a11 = hn::LoadN(df, ar1 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a11, b01, b11, b21, b31, C10, C11, C12,
-                               C13);
-        }
-
-        if constexpr (kRowsAC > 2) {
-          const VF a20 = hn::LoadN(df, ar2 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a20, b00, b10, b20, b30, C20, C21, C22,
-                               C23);
-        }
-        if constexpr (kRowsAC > 3) {
-          const VF a30 = hn::LoadN(df, ar3 + ikc, remaining_kc);
-          ElementwiseMulAccF32(df, a30, b00, b10, b20, b30, C30, C31, C32,
-                               C33);
-        }
-
-        if constexpr (kRowsAC > 2) {
-          const VF a21 = hn::LoadN(df, ar2 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a21, b01, b11, b21, b31, C20, C21, C22,
-                               C23);
-        }
-        if constexpr (kRowsAC > 3) {
-          const VF a31 = hn::LoadN(df, ar3 + ikc + NA, remaining2);
-          ElementwiseMulAccF32(df, a31, b01, b11, b21, b31, C30, C31, C32,
-                               C33);
-        }
+      // Two rows at a time so we have 8 separate dependency chains,
+      // sufficient for IPC=2 and 4-cycle latency.
+      {
+        const VBF a0 = hn::LoadN(dbf, ar0 + ikc, remaining_kc);
+        const VBF a1 =
+            kRowsAC > 1 ? hn::LoadN(dbf, ar1 + ikc, remaining_kc) : a0;
+        ElementwiseMulAccEmuBF(dbf, a0, a1, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+                               b3e, C00, C01, C02, C03, C10, C11, C12, C13);
+      }
+      if constexpr (kRowsAC > 2) {
+        const VBF a2 = hn::LoadN(dbf, ar2 + ikc, remaining_kc);
+        const VBF a3 =
+            kRowsAC > 3 ? hn::LoadN(dbf, ar3 + ikc, remaining_kc) : a2;
+        ElementwiseMulAccEmuBF(dbf, a2, a3, b0o, b0e, b1o, b1e, b2o, b2e, b3o,
+                               b3e, C20, C21, C22, C23, C30, C31, C32, C33);
      }
    }
  }  // remaining_kc != 0
@@ -699,16 +556,12 @@ class MMKernel {
 
    // This is a substantial fraction (about 1/3) of the total time, but is
    // called frequently, so do not add a profiler zone.
-    if constexpr (hwy::IsSame<Tag, MMSetC>()) {
-      MMStoreHorizontalSumsIntoC<kRowsAC, kNR, false>()(
-          df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C_rows);
-    } else {
-      static_assert(hwy::IsSame<Tag, MMAddC>());
-      MMStoreHorizontalSumsIntoC<kRowsAC, kNR, true>()(
-          df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-          C31, C32, C33, row_ac, col_c, args, C_rows);
-    }
+    MMStoreHorizontalSumsIntoC<kRowsAC, kNR> horz;
+    const hn::Full128<float> d4;
+    hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
+    horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
+                   C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
+    horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
   }
 };
 
@@ -717,15 +570,6 @@ class MMKernel {
 // outer M/N/K loops, and calls `A2C0` which loops over the inner KC and MC.
 // Its member variables avoid long argument lists in Do*().
 class MMPerPackage {
-  // Decompression is only required for F32 A and native BF16 dot products.
-  // If A is already BF16, we can use a view. Padding is not required
-  // because `LoopKC` can handle non-vector multiples. `LoopKC` also contains
-  // a special case for F32 `A` and non-native BF16 dot products.
-  template <typename TA>
-  static constexpr bool WantDecompressA() {
-    return HWY_NATIVE_DOT_BF16 && IsF32<TA>();
-  }
-
  public:
  MMPerPackage(const Extents2D A, const MMArgs& args, const MMConfig& config,
               size_t pkg_idx, size_t cluster_idx, const IndexRange& range_np)
@@ -741,13 +585,6 @@ class MMPerPackage {
        inner_tasks_(config.InnerTasks()),
        line_bytes_(args.env->ctx.allocator.LineBytes()) {}
 
-  // The size of `A` that will actually be used, for purposes of choosing the
-  // autotuning candidates. Keep in sync with the `operator()` logic below.
-  template <typename TA>
-  static constexpr size_t ABytes() {
-    return WantDecompressA<TA>() ? sizeof(BF16) : sizeof(TA);
-  }
-
  // B and maybe A are decompressed several call layers lower, but not all
  // member functions depend on TA/TB, so pass them as an argument instead of
  // templating the class.
@@ -755,12 +592,16 @@ class MMPerPackage {
  HWY_NOINLINE void operator()(const MMParallelPolicyT& parallel_policy,
                               const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               RowPtrs C_rows) const {
-    if constexpr (WantDecompressA<TA>()) {
+    if constexpr (IsBF16<TA>()) {
+      // We can use a view, regardless of columns/padding, because `LoopKC`
+      // supports non-vector multiples.
+      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
+    } else {
+      // Always decompress. To reduce code size/compile time, we no longer
+      // support a separate F32 kernel; most A are already BF16.
      const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
      DecompressA(A, A_view);
      DispatchOrder(parallel_policy, A_view, B, C_rows);
-    } else {
-      DispatchOrder(parallel_policy, View(A, 0, 0, A.Cols()), B, C_rows);
    }
  }
 
@@ -937,7 +778,7 @@ class MMPerPackage {
      // Sequential loop over NC/MC/KC, for when the M/N loops are
      // already parallel. This is B3A2C0 in MOMMS terminology: we read
      // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
-      const auto loop_nc = [&](const StridedViewBF& B_storage_view,
+      const auto loop_nc = [&](const StridedViewBF B_storage_view,
                                const IndexRange& range_mc,
                                const IndexRange& range_kc,
                                const IndexRange& range_nc,
@@ -1080,7 +921,7 @@ class MMPerPackage {
  template <typename TB>
  HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
                                       const IndexRange& range_kc,
-                                       const StridedViewBF& B_view) const {
+                                       const StridedViewBF B_view) const {
    const hn::ScalableTag<BF16> dbf;
    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
 
@@ -1229,7 +1070,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
  if (HWY_UNLIKELY(!tuner.HasCandidates())) {
    // Ensure matrix dimensions match each other.
    HWY_ASSERT(K == B.Cols());
-    HWY_ASSERT(M <= MMStorage::kMaxM);
+    HWY_ASSERT(M <= kMaxBatchSize);
    HWY_ASSERT(K <= MMStorage::kMaxK);
    HWY_ASSERT(N % kNR == 0);
    // Ensure A rows are vector-aligned. Neither `Stride` nor `IsPacked` are
@@ -1241,9 +1082,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
      HWY_ASSERT(hwy::IsAligned(A.Row(1), env.ctx.allocator.LineBytes()));
    }
 
-    tuner.SetCandidates(
-        MMCandidates(allocator, M, K, N, MMPerPackage::ABytes<TA>(), sizeof(TC),
-                     kMaxMR, kNR, per_key.ranges_np, env.print_config));
+    tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
+                                     kNR, per_key.ranges_np, env.print_config));
  }
 
  const MMConfig& cfg = tuner.NextConfig();
diff --git a/ops/matmul.cc b/ops/matmul.cc
index 711eac1..812fe99 100644
--- a/ops/matmul.cc
+++ b/ops/matmul.cc
@@ -64,21 +64,19 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 class GenerateCandidates {
  public:
  GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
-                     size_t sizeof_TA, size_t sizeof_TC, size_t max_mr,
-                     size_t nr, const IndexRangePartition& ranges_np,
-                     bool print_config)
+                     size_t sizeof_TC, size_t max_mr, size_t nr,
+                     const IndexRangePartition& ranges_np, bool print_config)
      : allocator_(allocator),
        M_(M),
        K_(K),
        N_(N),
-        sizeof_TA_(sizeof_TA),
        sizeof_TC_(sizeof_TC),
        max_mr_(max_mr),
        nr_(nr),
        // These influence kc/nc, but are also stored in `MMConfig` for
        // `RangesOf*`. Must be a vector multiple. The previous/next cache line
        // is likely still in L1, but we expect K > 1000 and might as well round
-        // up to the line size. Use BF16, not sizeof_TA, because B is BF16.
+        // up to the line size. Both A and B are BF16.
        kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
        nc_multiple_(allocator.StepBytes() / sizeof_TC),
        ranges_np_(ranges_np),
@@ -176,8 +174,8 @@ class GenerateCandidates {
    // size. This results in an overestimate, and the loop below will propose
    // the next few smaller values for the autotuner to evaluate.
    const size_t bytes_ab =
-        allocator_.L1Bytes() * (sizeof_TA_ + sizeof(SfpStream));
-    const size_t col_bytes = rows_a * sizeof_TA_ + nr_ * sizeof(BF16);
+        allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream));
+    const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
    size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
    kc_max = RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC),
                                kc_multiple_);
@@ -224,9 +222,9 @@ class GenerateCandidates {
    // packed B. We want `mc * kc` elements of A to fit in L2, alongside
    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
    // partial.
-    const size_t bytes_per_mc = kc * sizeof_TA_ + allocator_.LineBytes();
+    const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
    size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
-    mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
+    mc_max = HWY_MIN(mc_max, kMaxBatchSize);
    HWY_DASSERT(mc_max != 0);
    mc_max = HWY_MIN(mc_max, M_);
    mc_max = hwy::RoundDownTo(mc_max, mr);
@@ -340,7 +338,6 @@ class GenerateCandidates {
  const size_t M_;
  const size_t K_;
  const size_t N_;
-  const size_t sizeof_TA_;
  const size_t sizeof_TC_;
  const size_t max_mr_;
 
@@ -358,12 +355,12 @@ class GenerateCandidates {
 
 // Facade to avoid exposing `GenerateCandidates` in the header.
 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-                                   size_t K, size_t N, size_t sizeof_TA,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                   const IndexRangePartition& ranges_np,
                                   bool print_config) {
-  return GenerateCandidates(allocator, M, K, N, sizeof_TA, sizeof_TC, max_mr,
-                            nr, ranges_np, print_config)();
+  return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
+                            ranges_np, print_config)();
 }
 
 // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
@@ -409,7 +406,7 @@ MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
  char cpu100[100];
  have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
 
-  row_ptrs.push_back(hwy::AllocateAligned(MMStorage::kMaxM));  // C
+  row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize));  // C
 }
 
 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
diff --git a/ops/matmul.h b/ops/matmul.h
index 11262bc..70c7d20 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -330,10 +330,8 @@ using StridedViewD = StridedView<double>;
 
 // Per-package storage for packed A.
 class MMStorage {
  public:
-  // Compile-time bounds on matrix dimensions to enable pre-allocating storage
-  // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
-  // per package and 512 MiB, respectively.
-  static constexpr size_t kMaxM = 4096;
+  // Compile-time bounds on matrix columns to enable pre-allocating storage
+  // and reusing it across `MatMul` calls.
  static constexpr size_t kMaxK = 64 * 1024;
  // Upper bound for per-worker B storage on the stack. Chosen such that one row
  // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
@@ -348,8 +346,10 @@ class MMStorage {
    MMNestedParallelPolicy::ForPkg(ctx, kMaxPackages, [&](size_t pkg_idx) {
      Allocator& allocator = ctx.allocator;
 
-      pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
-          "pkg_A", Extents2D(kMaxM, kMaxK), allocator, MatPadding::kOdd));
+      // 0.5 GiB per package.
+      pkg_A_[pkg_idx].reset(
+          new MatStorageT<BF16>("pkg_A", Extents2D(kMaxBatchSize, kMaxK),
+                                allocator, MatPadding::kOdd));
 
      if (allocator.ShouldBind()) {
        const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
@@ -367,7 +367,7 @@ class MMStorage {
  // faster than on-the-fly when native BF16 is available: it only happens once,
  // not per B tile row, and the cache footprint is smaller.
  StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
-    HWY_DASSERT(extents.rows <= kMaxM);
+    HWY_DASSERT(extents.rows <= kMaxBatchSize);
    HWY_DASSERT(extents.cols <= kMaxK);
    return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
                         extents.cols, pkg_A_[pkg_idx]->Stride());
  }
@@ -527,8 +527,8 @@ static_assert(sizeof(MMConfig) == 32);  // for faster indexing
 #pragma pack(pop)
 
 std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
-                                   size_t K, size_t N, size_t sizeof_TA,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                   const IndexRangePartition& ranges_np,
                                   bool print_config);
 
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index dc6f559..665e337 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -120,9 +120,9 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
  // Dot() uses double-precision summation.
  double tolerance = 20 * norm * eps_f32;
-  // If B is F32, Dot() promotes F32 or even F64, but MatMul demotes the F32 to
-  // BF16, so add extra tolerance.
-  if (IsF32<TB>()) {
+  // If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
+  // F32 to BF16, so add extra tolerance.
+  if (IsF32<TA>() || IsF32<TB>()) {
    tolerance += 2 * max_abs * eps_bf16;
  }
 
diff --git a/util/basics.h b/util/basics.h
index 30864b2..13d0362 100644
--- a/util/basics.h
+++ b/util/basics.h
@@ -33,7 +33,10 @@ namespace gcpp {
 // Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
 // runtime `max_packages` does not exceed this. MatMul's outer per-package loop
 // is disabled if this is 1.
-constexpr size_t kMaxPackages = 1;
+HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;
+
+// TODO: extend to 16k after updating non_eos.
+HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;
 
 enum class Tristate : int32_t { kFalse = 0, kTrue = 1, kDefault = -1 };
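Illustration of the new store path (not part of the patch): the kernel now reduces the sixteen accumulators once via Reduce4x4 and then dispatches the write on the MMSetC/MMAddC tag, where MMSetC writes scale * sum plus the bias and MMAddC accumulates a partial sum into the existing C value. The following is a minimal scalar C++ sketch of that tag dispatch under simplified assumptions (plain float arrays, a hypothetical StoreTile helper); the real code operates on Highway vectors and RowPtrs as shown in the diff above.

// Scalar sketch of the MMSetC/MMAddC tag dispatch; StoreTile is a
// hypothetical stand-in for MaybeScaleAndStore, not a gemma.cpp function.
#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <vector>

struct MMSetC {};  // first write: start from the bias
struct MMAddC {};  // later writes: accumulate into the prior C value

template <class Tag>
void StoreTile(float* c_row, size_t col_c, const float sums[4], float scale,
               const float* add /* bias row or nullptr */, Tag) {
  for (size_t i = 0; i < 4; ++i) {
    float base;
    if constexpr (std::is_same_v<Tag, MMAddC>) {
      base = c_row[col_c + i];  // load prior value
    } else {
      static_assert(std::is_same_v<Tag, MMSetC>);
      base = add ? add[col_c + i] : 0.0f;  // bias, added only once
    }
    c_row[col_c + i] = sums[i] * scale + base;
  }
}

int main() {
  std::vector<float> c(8, 0.0f);
  const float bias[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  const float partial1[4] = {1, 2, 3, 4};
  const float partial2[4] = {10, 20, 30, 40};
  StoreTile(c.data(), 0, partial1, 0.5f, bias, MMSetC());  // first K slice
  StoreTile(c.data(), 0, partial2, 0.5f, bias, MMAddC());  // accumulate
  for (float v : c) std::printf("%g ", v);  // 6.5 12 17.5 23 0 0 0 0
  std::printf("\n");
  return 0;
}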