Update MatMul comments, removing mention of partial.

PiperOrigin-RevId: 804872289
Jan Wassenberg 2025-09-09 05:56:57 -07:00 committed by Copybara-Service
parent a5ab99e4ba
commit 34ceee6c30
2 changed files with 9 additions and 11 deletions

View File

@@ -132,7 +132,7 @@ class MMStoreHorizontalSumsIntoC {
 // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
 // the elements of one V4. We have four independent rows `r`, hence the
 // code is effectively unrolled, which increases throughput.
-// Store to four elements per row of `partial`.
+// Store to four elements per row of `C`.
 // No loop is required because vectors are at least 4*32 bits.
 const D4 d4;
 sum0 = MaybeLoad<0>(d4, N, buf);
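
(Editor's note: a minimal scalar sketch of the effect the comment above describes, assuming a `[row][column][lane]` accumulator layout; the function name, parameters, and layout are illustrative placeholders, not the vectorized Highway code.)

#include <cstddef>

// acc[r][c][lane]: per-lane partial sums for tile row r, tile column c.
// Horizontally summing the lanes yields the four values written per row of C.
template <size_t kLanes>
void StoreHorizontalSumsIntoC(const float acc[4][4][kLanes], float* C,
                              size_t C_stride, size_t row0, size_t col0) {
  for (size_t r = 0; r < 4; ++r) {    // four independent rows
    for (size_t c = 0; c < 4; ++c) {  // four elements per row of C
      float sum = 0.0f;
      for (size_t lane = 0; lane < kLanes; ++lane) sum += acc[r][c][lane];
      C[(row0 + r) * C_stride + (col0 + c)] = sum;
    }
  }
}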
@@ -370,7 +370,7 @@ class MMKernel {
 // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
 // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
 // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-// Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
+// Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
 // `A` and `B` are always BF16, `C` can be F32 or BF16.
 template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
 static HWY_INLINE void LoopKC(const StridedViewBF A_view,
@@ -966,8 +966,7 @@ class MMState {
 const size_t B_stride =
     Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
-// Sequential loop over NC/MC/KC, similar to `loop_nc` below
-// except for the profiler strings and `out_tag`.
+// Similar to `loop_nc` below except for the profiler zone and `MMSetC`.
 parallel.ForRangesMC_NC(
     args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
     [&](const IndexRange& range_mc, const IndexRange& range_nc,
@@ -990,7 +989,7 @@ class MMState {
 }
 // Parallel loops over mc/nc blocks of M/range_np, sequential K.
-// Fills `mc x nc` sections of `partial`, then `C`, in parallel.
+// Accumulates into `mc x nc` sections of `C`.
 template <typename TB, typename TC, class ParallelT>
 HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
                           const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
@@ -1001,7 +1000,7 @@ class MMState {
     Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
 // Sequential loop over NC/MC/KC, for when the M/N loops are
 // already parallel. This is B3A2C0 in MOMMS terminology: we read
-// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
+// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`.
 const auto loop_nc = [&](const StridedViewBF B_storage_view,
                          const IndexRange& range_mc,
                          const IndexRange& range_kc,
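
(Editor's note: for readers unfamiliar with the blocking described in the comment above, here is a hedged scalar sketch of a sequential NC/MC/KC loop that reads an `mc x kc` panel of A and an `nc x kc` panel of B and accumulates into an `mc x nc` block of C. It assumes row-major A and C and a transposed, `nc x kc`-panelled B; it is not the library's `loop_nc`, and all names are placeholders.)

#include <algorithm>
#include <cstddef>

void BlockedMatMul(const float* A, const float* B, float* C,
                   size_t M, size_t N, size_t K,
                   size_t mc_max, size_t nc_max, size_t kc_max) {
  for (size_t nc0 = 0; nc0 < N; nc0 += nc_max) {
    const size_t nc = std::min(nc_max, N - nc0);
    for (size_t mc0 = 0; mc0 < M; mc0 += mc_max) {
      const size_t mc = std::min(mc_max, M - mc0);
      for (size_t kc0 = 0; kc0 < K; kc0 += kc_max) {  // accumulate over K
        const size_t kc = std::min(kc_max, K - kc0);
        for (size_t i = 0; i < mc; ++i) {
          for (size_t j = 0; j < nc; ++j) {
            // First kc block sets C; later blocks accumulate into it.
            float sum = (kc0 == 0) ? 0.0f : C[(mc0 + i) * N + nc0 + j];
            for (size_t k = 0; k < kc; ++k) {
              sum += A[(mc0 + i) * K + kc0 + k] * B[(nc0 + j) * K + kc0 + k];
            }
            C[(mc0 + i) * N + nc0 + j] = sum;
          }
        }
      }
    }
  }
}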

View File

@@ -332,8 +332,7 @@ class MMStorage {
 // Autotuning
 // Naming convention: outer loop first, T suffix means threaded. This refers to
-// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost
-// `ranges_np` loop across packages is implicit and applies to all of these.
+// the loops *around* `A2C0`, which contains loops over mc/kc.
 //
 // Parallelizing across K (A/B columns) is undesirable because the resulting
 // partial dot products require synchronization or reduction across threads.
@@ -341,18 +340,18 @@ enum class MMOrder : uint8_t {
 // Single M, parallel N, sequential K (inside the parallel section to
 // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K.
 kNT_K,
-// Specialization of `kNT_K` for a single K task with `kDirect`.
+// Specialization of `kNT_K` for a single K task with `MMSetC`.
 kNT,
 // Parallelize over blocks of M and N: good when both are large. We no longer
 // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as
 // fast on Zen4.
 kNT_MT_K,
-kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`.
+kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `MMSetC`.
 // Resident C (`kK_M_NT`) should be good for large K relative to M and N.
 // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
-// no kN* because we expect M (batch size) to be small relative to K and N.
+// no kM* because we expect M (batch size) to be small relative to K and N.
 };
 static inline bool IsBlock(MMOrder order) {
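
(Editor's note: as a rough illustration of the naming convention above (outer loop first, T suffix = threaded), the following sketch shows the shape implied by `kNT_K`: one parallel fork over blocks of N, with the K loop running sequentially inside each worker so there is only a single fork-join. This is an assumption-laden sketch, not the library's dispatch; `ProcessBlock` stands in for the A2C0 kernel and all names are placeholders.)

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Placeholder for the A2C0 kernel: updates an mc x nc tile for one kc block.
void ProcessBlock(size_t nc0, size_t nc, size_t kc0, size_t kc) {}

void NT_K(size_t N, size_t K, size_t nc_max, size_t kc_max) {
  const size_t num_nc = (N + nc_max - 1) / nc_max;
  std::vector<std::thread> workers;
  for (size_t task = 0; task < num_nc; ++task) {  // parallel over N blocks
    workers.emplace_back([=] {
      const size_t nc0 = task * nc_max;
      const size_t nc = std::min(nc_max, N - nc0);
      for (size_t kc0 = 0; kc0 < K; kc0 += kc_max) {  // sequential K inside
        ProcessBlock(nc0, nc, kc0, std::min(kc_max, K - kc0));
      }
    });
  }
  for (std::thread& t : workers) t.join();  // one fork-join for the whole op
}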