mirror of https://github.com/google/gemma.cpp.git
Update MatMul comments, removing mention of partial.
PiperOrigin-RevId: 804872289
This commit is contained in:
parent a5ab99e4ba
commit 34ceee6c30
@@ -132,7 +132,7 @@ class MMStoreHorizontalSumsIntoC {
   // Adding N consecutive V4 yields horizontal sums of Cr0, Cr1, Cr2, Cr3 in
   // the elements of one V4. We have four independent rows `r`, hence the
   // code is effectively unrolled, which increases throughput.
-  // Store to four elements per row of `partial`.
+  // Store to four elements per row of `C`.
   // No loop is required because vectors are at least 4*32 bits.
   const D4 d4;
   sum0 = MaybeLoad<0>(d4, N, buf);
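For intuition about the comment in this hunk, here is a minimal scalar sketch. The layout is an assumption (the real code uses Highway vectors via `MaybeLoad` and, as the comment says, needs no loop because vectors are at least 4*32 bits): if `buf` holds the rows' partial sums transposed into N consecutive groups of four floats, element-wise adding the groups leaves the horizontal sums of Cr0..Cr3 in the four elements of one group.

#include <array>
#include <cstddef>

// Scalar sketch only; assumes buf[i * 4 + r] is lane i of row r's
// partial sums, which is an assumption about the layout, not a fact
// taken from gemma.cpp.
std::array<float, 4> HorizontalSums4(const float* buf, size_t N) {
  std::array<float, 4> sums = {0.0f, 0.0f, 0.0f, 0.0f};
  for (size_t i = 0; i < N; ++i) {    // N consecutive "V4" groups
    for (size_t r = 0; r < 4; ++r) {  // element-wise add of one group
      sums[r] += buf[i * 4 + r];
    }
  }
  return sums;  // {sum(Cr0), sum(Cr1), sum(Cr2), sum(Cr3)}
}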
@@ -370,7 +370,7 @@ class MMKernel {
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-  // Updates a `kRowsAC x kNR` tile with top-left `partial.Row(row_ac) + col_c`.
+  // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
   // `A` and `B` are always BF16, `C` can be F32 or BF16.
   template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
   static HWY_INLINE void LoopKC(const StridedViewBF A_view,
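A hedged scalar model of the contract in the updated comment (the real `LoopKC` is a vectorized BF16 Highway kernel; the name, float types, and inner loop here are illustrative only): `A_view` has `kRowsAC` rows of `kc` columns, `B_view` has `kNR` rows of `kc` columns, and the `kRowsAC x kNR` tile of `C` whose top-left is at column `col_c` is updated with their dot products.

#include <cstddef>

// Scalar sketch of the tile update; not the actual kernel.
template <size_t kRowsAC, size_t kNR, typename TC>
void LoopKCSketch(const float* A_view, size_t A_stride, const float* B_view,
                  size_t B_stride, size_t kc, TC* const C_rows[kRowsAC],
                  size_t col_c) {
  for (size_t r = 0; r < kRowsAC; ++r) {
    for (size_t n = 0; n < kNR; ++n) {
      float dot = 0.0f;
      for (size_t k = 0; k < kc; ++k) {  // innermost loop over kc columns
        dot += A_view[r * A_stride + k] * B_view[n * B_stride + k];
      }
      C_rows[r][col_c + n] += static_cast<TC>(dot);  // update the C tile
    }
  }
}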
@@ -966,8 +966,7 @@ class MMState {
     const size_t B_stride =
         Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
 
-    // Sequential loop over NC/MC/KC, similar to `loop_nc` below
-    // except for the profiler strings and `out_tag`.
+    // Similar to `loop_nc` below except for the profiler zone and `MMSetC`.
     parallel.ForRangesMC_NC(
         args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
@@ -990,7 +989,7 @@ class MMState {
   }
 
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
-  // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
+  // Accumulates into `mc x nc` sections of `C`.
   template <typename TB, typename TC, class ParallelT>
   HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
                             const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
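A hedged sketch of the decomposition named by this comment: each (mc, nc) block owns a disjoint `mc x nc` section of `C`, so blocks can run in parallel without synchronization, while K stays sequential inside each task. The thread-per-block scheme below is illustrative; gemma.cpp dispatches via its own thread pool (`parallel.ForRangesMC_NC`), not std::thread.

#include <cstddef>
#include <thread>
#include <vector>

// Illustrative only: spawn one task per C block; func(mc0, nc0) would run
// the sequential K loop for that block.
template <typename BlockFunc>
void ForEachBlockMC_NC(size_t M, size_t N, size_t MC, size_t NC,
                       const BlockFunc& func) {
  std::vector<std::thread> workers;
  for (size_t mc0 = 0; mc0 < M; mc0 += MC) {
    for (size_t nc0 = 0; nc0 < N; nc0 += NC) {
      workers.emplace_back(func, mc0, nc0);  // one task per C block
    }
  }
  for (std::thread& t : workers) t.join();
}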
@@ -1001,7 +1000,7 @@ class MMState {
         Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
     // Sequential loop over NC/MC/KC, for when the M/N loops are
     // already parallel. This is B3A2C0 in MOMMS terminology: we read
-    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
+    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`.
     const auto loop_nc = [&](const StridedViewBF B_storage_view,
                              const IndexRange& range_mc,
                              const IndexRange& range_kc,
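A minimal sketch of the B3A2C0 loop order described in this hunk, assuming A is M x K, B is N x K (row-major with K inner, matching "`nc x kc` of B"), and C is M x N and zero-initialized by the caller. The block sizes and scalar inner loops are illustrative, not the real Highway kernel; the point is the nesting: KC is innermost, so each `mc x nc` section of C stays resident while it accumulates.

#include <algorithm>
#include <cstddef>

// Sketch of sequential NC/MC/KC (B3A2C0) blocking; scalar stand-in for the
// vectorized kernel.
void MatMulB3A2C0(const float* A, const float* B, float* C, size_t M,
                  size_t N, size_t K, size_t MC, size_t NC, size_t KC) {
  for (size_t nc0 = 0; nc0 < N; nc0 += NC) {      // loop_nc: nc rows of B
    const size_t nc = std::min(NC, N - nc0);
    for (size_t mc0 = 0; mc0 < M; mc0 += MC) {    // loop_mc: mc rows of A
      const size_t mc = std::min(MC, M - mc0);
      for (size_t kc0 = 0; kc0 < K; kc0 += KC) {  // loop_kc: sequential K
        const size_t kc = std::min(KC, K - kc0);
        for (size_t i = 0; i < mc; ++i) {
          for (size_t j = 0; j < nc; ++j) {
            float dot = 0.0f;
            for (size_t k = 0; k < kc; ++k) {
              dot += A[(mc0 + i) * K + kc0 + k] * B[(nc0 + j) * K + kc0 + k];
            }
            C[(mc0 + i) * N + nc0 + j] += dot;  // update mc x nc of C
          }
        }
      }
    }
  }
}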
@@ -332,8 +332,7 @@ class MMStorage {
 // Autotuning
 
 // Naming convention: outer loop first, T suffix means threaded. This refers to
-// the loops *around* `A2C0`, which contains loops over mc/kc. The outermost
-// `ranges_np` loop across packages is implicit and applies to all of these.
+// the loops *around* `A2C0`, which contains loops over mc/kc.
 //
 // Parallelizing across K (A/B columns) is undesirable because the resulting
 // partial dot products require synchronization or reduction across threads.
@@ -341,18 +340,18 @@ enum class MMOrder : uint8_t {
   // Single M, parallel N, sequential K (inside the parallel section to
   // reduce fork-joins). Similar to GotoBLAS, good for large N vs. M and K.
   kNT_K,
-  // Specialization of `kNT_K` for a single K task with `kDirect`.
+  // Specialization of `kNT_K` for a single K task with `MMSetC`.
   kNT,
 
   // Parallelize over blocks of M and N: good when both are large. We no longer
   // support `kMT_NT_K`: no advantage on Skylake, and `kNT_MT_K` is 1.5x as
   // fast on Zen4.
   kNT_MT_K,
-  kNT_MT,  // Specialization of `kNT_MT_K` for a single K task with `kDirect`.
+  kNT_MT,  // Specialization of `kNT_MT_K` for a single K task with `MMSetC`.
 
   // Resident C (`kK_M_NT`) should be good for large K relative to M and N.
   // However, it does not (much) outperform `kNT_K` on SKX and Zen4. There are
-  // no kN* because we expect M (batch size) to be small relative to K and N.
+  // no kM* because we expect M (batch size) to be small relative to K and N.
 };
 
 static inline bool IsBlock(MMOrder order) {