From 34ceee6c308a362ce8502d19580baa65256f4b34 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 9 Sep 2025 05:56:57 -0700 Subject: [PATCH] Update MatMul comments, removing mention of partial. PiperOrigin-RevId: 804872289 --- ops/matmul-inl.h | 11 +++++------ ops/matmul.h | 9 ++++----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 74deb78..f2e9c49 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -132,7 +132,7 @@ class MMStoreHorizontalSumsIntoC { // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in // the elements of one V4. We have four independent rows `r`, hence the // code is effectively unrolled, which increases throughput. - // Store to four elements per row of `partial`. + // Store to four elements per row of `C`. // No loop is required because vectors are at least 4*32 bits. const D4 d4; sum0 = MaybeLoad<0>(d4, N, buf); @@ -370,7 +370,7 @@ class MMKernel { // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view` // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0). - // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`. + // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`. // `A` and `B` are always BF16, `C` can be F32 or BF16. template static HWY_INLINE void LoopKC(const StridedViewBF A_view, @@ -966,8 +966,7 @@ class MMState { const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); - // Sequential loop over NC/MC/KC, similar to `loop_nc` below - // except for the profiler strings and `out_tag`. + // Similar to `loop_nc` below except for the profiler zone and `MMSetC`. parallel.ForRangesMC_NC( args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx, [&](const IndexRange& range_mc, const IndexRange& range_nc, @@ -990,7 +989,7 @@ class MMState { } // Parallel loops over mc/nc blocks of M/range_np, sequential K. - // Fills `mc x nc` sections of `partial`, then `C`, in parallel. + // Accumulates into `mc x nc` sections of `C`. template HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A, const MatPtrT& B, RowPtrs C_rows) const { @@ -1001,7 +1000,7 @@ class MMState { Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read - // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. + // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`. const auto loop_nc = [&](const StridedViewBF B_storage_view, const IndexRange& range_mc, const IndexRange& range_kc, diff --git a/ops/matmul.h b/ops/matmul.h index c86ecc3..946673a 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -332,8 +332,7 @@ class MMStorage { // Autotuning // Naming convention: outer loop first, T suffix means threaded. This refers to -// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost -// `ranges_np` loop across packages is implicit and applies to all of these. +// the loops *around* `A2C0`, which contains loops over mc/kc. // // Parallelizing across K (A/B columns) is undesirable because the resulting // partial dot products require synchronization or reduction across threads. @@ -341,18 +340,18 @@ enum class MMOrder : uint8_t { // Single M, parallel N, sequential K (inside the parallel section to // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K. kNT_K, - // Specialization of `kNT_K` for a single K task with `kDirect`. + // Specialization of `kNT_K` for a single K task with `MMSetC`. kNT, // Parallelize over blocks of M and N: good when both are large. We no longer // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as // fast on Zen4. kNT_MT_K, - kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `kDirect`. + kNT_MT, // Specialization of `kNT_MT_K` for a single K task with `MMSetC`. // Resident C (`kK_M_NT`) should be good for large K relative to M and N. // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are - // no kN* because we expect M (batch size) to be small relative to K and N. + // no kM* because we expect M (batch size) to be small relative to K and N. }; static inline bool IsBlock(MMOrder order) {