Refactor MatMul to accept views in the kernel functions

Make arg order consistent.
Move StridedView into mat.h.
Add view support to RowPtrs.

PiperOrigin-RevId: 805197381
Jan Wassenberg 2025-09-09 22:09:09 -07:00 committed by Copybara-Service
parent f10ac41a20
commit 9457258330
3 changed files with 192 additions and 146 deletions
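The gist of the refactor, as a minimal standalone sketch (simplified types and hypothetical names, not the actual gemma.cpp classes): kernels now receive views whose origin is already offset to the current tile, plus only the scalars they need (`scale`, `add`), instead of absolute row/column indices and the full `MMArgs`.

#include <cstddef>

// Hypothetical stand-in for StridedView<T>/RowPtrs<T>: Row(r) is relative to
// the view's own origin, so kernels no longer add row0/col_c themselves.
template <typename T>
struct TileView {
  T* row0;
  size_t stride;
  T* Row(size_t r) const { return row0 + r * stride; }
};

// Old shape (sketch): Store(..., row_c, col_c, args, C_rows), with the kernel
// computing C_rows[row_c + k] + col_c. New shape (sketch): the caller makes a
// pre-offset view and passes scale/add directly.
void StoreRow(TileView<float> C_view, size_t imc, float scale,
              const float* add, const float* sums, size_t n) {
  float* pos = C_view.Row(imc);  // already at the tile's top-left column
  for (size_t i = 0; i < n; ++i) {
    pos[i] = scale * sums[i] + (add ? add[i] : 0.0f);
  }
}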

View File

@@ -148,21 +148,21 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
-  // Scales the dot-product terms and adds bias (if present) and stores the
-  // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
-  // `MMSetC`, the vectors are written as-is (first call, or small K).
-  // Otherwise, they are partial sums and are accumulated into C.
-  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CRows>
-  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
-                        const size_t row_c, const size_t col_c,
-                        const MMArgs& args, CRows C_rows) const {
-    const V4 vscale = hn::Set(d4, args.scale);
+  // Scales the dot-product terms plus `add` (if non-null) and stores the four
+  // 4-wide vectors to `C` starting at row 0, column 0. If `tag` is `MMSetC`,
+  // the vectors are written as-is (first call, or small K). Otherwise, they
+  // are partial sums and are accumulated into C.
+  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
+  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
+                        const float scale, const float* HWY_RESTRICT add,
+                        const size_t imc, Tag tag, CView C_rows) const {
+    const V4 vscale = hn::Set(d4, scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
-    const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
+    const V4 vadd = hn::Load(d4, add ? add : kZero);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
   }
  private:
@@ -199,13 +199,13 @@ class MMStoreHorizontalSumsIntoC {
   }
   template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
-            class Tag, typename TC>
+            class Tag, class CView>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, Tag, RowPtrs<TC> C_rows,
-                                            const size_t row_c,
-                                            const size_t col_c) {
+                                            VF4 vadd, Tag, const size_t imc,
+                                            CView C_view) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
+      using TC = hwy::RemoveCvRef<decltype(C_view.Row(0)[0])>;
+      TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
       const hn::Rebind<TC, DF4> dc4;
       if constexpr (hwy::IsSame<Tag, MMAddC>()) {
         vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
@@ -234,7 +234,7 @@ class MMDecompress {
     // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
+      return StridedViewBF(B, row_b, range_kc.begin(), range_kc.Num());
     }
     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -264,7 +264,7 @@ class MMDecompress {
     if constexpr (IsBF16<TA>()) {
       // We can use a view, regardless of columns/padding, because
       // `MMKernel::LoopKC` supports non-vector multiples.
-      return View(A, 0, 0, A.Cols());
+      return StridedViewBF(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16. We also only
@@ -277,15 +277,6 @@ class MMDecompress {
   }
  private:
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
   // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
   static HWY_NOINLINE void DecompressA(const MatPtrT<float>& A,
                                        const StridedViewBF A_view,
@@ -402,26 +393,26 @@ class MMKernel {
       kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
  public:
-  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
-  // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
+  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
+  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
+  // at row/col 0. `CView` is either `RowPtrs<TC>` or `StridedView<TC>`.
   // Called by B3A2C0 and by callers that hoist `A_view`.
-  template <class Tag, class CRows>
+  template <class Tag, class CView>
   static HWY_INLINE void A2C0(const StridedViewBF A_view,
                               const StridedViewBF B_view, size_t mr,
-                              const IndexRange& range_mc, const size_t row_b,
-                              size_t kc, Tag tag, const MMArgs& args,
-                              CRows C_rows) {
+                              const IndexRange& range_mc, size_t kc,
+                              const float scale, const float* HWY_RESTRICT add,
+                              Tag tag, CView C_view) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
-    const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
     size_t imc = 0;
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
       return;
     }
@@ -430,13 +421,11 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                    C_rows);
+          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
       return;
     }
@@ -444,18 +433,17 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -466,11 +454,11 @@ class MMKernel {
   // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
   // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
   // `ForeachKC` and when there is only a single KC task.
-  template <typename TB, typename Tag, class CRows>
+  template <typename TB, typename TC, typename Tag>
   static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
-                     const MMArgs& args, const IndexRange& range_mc,
-                     const IndexRange& range_kc, const IndexRange& range_nc,
-                     size_t mr, Tag out_tag, CRows C_rows) {
+                     const IndexRange& range_mc, const IndexRange& range_kc,
+                     const IndexRange& range_nc, const MMArgs& args,
+                     Tag out_tag, RowPtrs<TC> C) {
     HWY_ALIGN BF16 B_storage[B_storage_max];
     const size_t kc = range_kc.Num();
@@ -482,24 +470,28 @@ class MMKernel {
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
-      StridedViewBF B_view =
+      const StridedViewBF B_view =
           MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
-      A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
+      const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
+      const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
+      A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
+           C_view);
     }
   }
-  template <typename TB, class CRows>
+  template <typename TB, typename TC>
   static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
-                        const MMArgs& args, const IndexRange& range_mc,
+                        const IndexRange& range_mc,
                         const IndexRangePartition& ranges_kc,
-                        const IndexRange& range_nc, size_t mr, CRows C_rows) {
+                        const IndexRange& range_nc, const MMArgs& args,
+                        RowPtrs<TC> C) {
     // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
     // by writing directly into it, and later accumulating into it.
     ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
     });
     ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
     });
   }
@@ -593,19 +585,20 @@ class MMKernel {
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-  // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
-  // `A` and `B` are always BF16, `C` can be F32 or BF16.
-  template <size_t kRowsAC, /*deduced:*/ class Tag, class CRows>
+  // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`, column 0.
+  // `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is also
+  // relative to the C column.
+  template <size_t kRowsAC, /*deduced:*/ class Tag, class CView>
   static HWY_INLINE void LoopKC(const StridedViewBF A_view,
-                                const StridedViewBF B_view, size_t row_ac,
-                                size_t imc, size_t col_c, size_t kc, Tag tag,
-                                const MMArgs& args, CRows C_rows) {
+                                const StridedViewBF B_view, size_t imc,
+                                size_t kc, const float scale,
+                                const float* HWY_RESTRICT add, Tag tag,
+                                CView C_view) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
     HWY_DASSERT(kRowsAC <= kMaxMR);
-    HWY_DASSERT(col_c % kNR == 0);
     // Rows are aligned to `kMaxMR`, except for the last tile of A.
     // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
@@ -784,7 +777,7 @@ class MMKernel {
     hn::Vec<decltype(d4)> sum0, sum1, sum2, sum3;
     horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
                    C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
-    horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
+    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
   }
 };
@@ -884,7 +877,7 @@ class MMLoops {
   // or with the best config.
   template <typename TB, typename TC>
   static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
-                                    RowPtrs<TC> C_rows, const MMArgs& args) {
+                                    RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
     PROFILER_ZONE3(args.env.ctx.profiler,
                    args.env.ctx.Worker(args.options.cluster_idx), zone);
@@ -892,7 +885,7 @@ class MMLoops {
     DispatchParallelism(
         args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
           DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
-            Loop(order, parallel, A, B, C_rows, args);
+            Loop(order, parallel, A, B, C, args);
           });
         });
   }
@@ -904,11 +897,11 @@ class MMLoops {
     return HWY_MAX(kNR, line_bytes / sizeof_TC);
   }
-  // Single M and K ranges, parallel N. Fills all of C directly.
+  // Single M and K ranges, parallel N.
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
@@ -932,10 +925,21 @@ class MMLoops {
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
-            StridedViewBF B_view =
+            const StridedViewBF B_view =
                 MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
-            MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
-                           args, C_rows);
+            const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
+            const float* HWY_RESTRICT add =
+                args.add ? args.add + row_b : nullptr;
+            MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
+                           MMSetC(), C_view);
+          }
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_M, range_nc, C2, worker);
+            }
           }
         });
   }
@@ -944,7 +948,7 @@ class MMLoops {
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -955,17 +959,24 @@ class MMLoops {
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc,
-                              range_nc, args.mr, C_rows);
+          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
+                              range_nc, args, C);
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
   // Parallel loops over mc/nc blocks of M/range_n, single K.
-  // Fills `mc x nc` sections of C directly, in parallel.
+  // Fills `mc x nc` sections of C.
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
     const IndexRange& range_K = args.ranges_kc.Range(0);
@@ -976,17 +987,24 @@ class MMLoops {
                              size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr,
-                           MMSetC(), C_rows);
+          MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
+                           C);
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
-  // Parallel loops over mc/nc blocks of M/range_np, sequential K.
+  // Parallel loops over mc/nc blocks of M/range_n, sequential K.
   // Accumulates into `mc x nc` sections of `C`.
   template <typename TB, typename TC, class Parallel>
   static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
     parallel.ForRangesMC_NC(
@@ -995,8 +1013,15 @@ class MMLoops {
                            size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc,
-                              args.mr, C_rows);
+          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
+                              C);
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
 };  // MMLoops

View File

@@ -60,54 +60,6 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;  // TODO: shrink?
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
 HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
-// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
-// Also used to decompress B, hence non-const.
-#pragma pack(push, 1)  // power of two size
-template <typename T>
-class StridedView {
- public:
-  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
-      : row0_(row0),
-        cols_(static_cast<uint32_t>(cols)),
-        stride_(static_cast<uint32_t>(stride)) {
-    HWY_DASSERT(stride >= cols);
-  }
-  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
-  size_t Cols() const { return static_cast<size_t>(cols_); }
-  size_t Stride() const { return static_cast<size_t>(stride_); }
-  void SetStride(size_t stride) {
-    HWY_DASSERT(stride >= Cols());
-    stride_ = stride;
-  }
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  StridedView<T> View(size_t r, size_t c, size_t cols) const {
-    HWY_DASSERT(c < Cols());
-    HWY_DASSERT(cols <= Cols() - c);
-    return StridedView<T>(Row(r) + c, cols, stride_);
-  }
- private:
-  T* HWY_RESTRICT row0_;
-  uint32_t cols_;
-  uint32_t stride_;
-};
-#pragma pack(pop)
-using StridedViewBF = StridedView<BF16>;
-using StridedViewD = StridedView<double>;
-using MMFused = std::function<void(StridedViewBF, size_t, size_t)>;
-struct MMOptions {
-  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
-  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
-  MMFused fused;
-};
 // Policy classes for parallelism, implementing some of `ParallelismStrategy`.
 struct MMParallelNone {
@@ -735,6 +687,19 @@ struct MatMulEnv {
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
+// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
+// that this thread has just filled, a view into a second tile (only for the
+// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
+using MMFused = std::function<void(RowPtrsBF, IndexRange, IndexRange,
+                                   StridedViewBF, size_t)>;
+struct MMOptions {
+  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
+  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+  MMFused fused;  // called if non-null and `TC` is BF16.
+};
 // Arguments to MatMul() that are independent of the A/B/C types. Reduces
 // register pressure compared to individual values/references. Also used for
 // passing through `DispatchOrder`.
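For orientation (not part of this commit's diff), a hedged sketch of how a caller might register the new fused epilogue, assuming the types declared above (`MMOptions`, `RowPtrsBF`, `IndexRange`, `StridedViewBF`) and a BF16 output:

// Sketch only: post-process the rows of C that this worker just filled. `C`
// is the whole matrix; `rows`/`cols` are the M/N sub-ranges finished by this
// worker; the second tile is unused (reserved for the upcoming `GatedMatmul`).
MMOptions options;
options.fused = [](RowPtrsBF C, IndexRange rows, IndexRange cols,
                   StridedViewBF /*second_tile*/, size_t worker) {
  (void)worker;
  for (size_t r = rows.begin(); r < rows.end(); ++r) {
    BF16* row = C.Row(r);  // columns cols.begin() .. cols.end() belong to us
    for (size_t c = cols.begin(); c < cols.end(); ++c) {
      (void)row[c];  // e.g. apply an activation or quantization step here
    }
  }
};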

View File

@@ -38,17 +38,27 @@ namespace gcpp {
 template <typename T>
 class RowPtrs {
  public:
-  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
+  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
+  RowPtrs View(size_t r, size_t c) {
+    RowPtrs<T> view(row_ptrs_);
+    view.r0_ = static_cast<uint32_t>(r);
+    view.c0_ = static_cast<uint32_t>(c);
+    return view;
+  }
   T* HWY_RESTRICT Row(size_t row_idx) const {
-    return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
+    return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_;
   }
-  T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); }
  private:
   uint8_t** row_ptrs_;
+  uint32_t r0_;
+  uint32_t c0_;
 };
+using RowPtrsBF = RowPtrs<BF16>;
 // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
 // or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
 // to store hetereogeneous tensor references in a vector.
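Aside (editorial, not in the diff): the effect of the new `View(r, c)` plus the offset-aware `Row()` can be modeled in a few lines. This is a simplified standalone sketch, using `float` rows instead of the type-erased `uint8_t**` storage and a hypothetical name:

#include <cstddef>
#include <cstdint>

// Simplified model: a view records absolute row/column offsets (r0, c0) and
// Row(i) resolves row r0 + i, then advances c0 elements into that row.
struct MiniRowPtrs {
  float** rows;  // one pointer per row of the full matrix
  uint32_t r0 = 0, c0 = 0;
  MiniRowPtrs View(size_t r, size_t c) const {
    MiniRowPtrs v{rows};
    v.r0 = static_cast<uint32_t>(r);
    v.c0 = static_cast<uint32_t>(c);
    return v;
  }
  float* Row(size_t i) const { return rows[r0 + i] + c0; }
};

This is why the matmul loops can call `C.View(range_mc.begin(), row_b)` and let the kernels index from row 0, column 0: the view's `Row(imc)` already lands on C row `range_mc.begin() + imc` at column `row_b`.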
@@ -349,12 +359,12 @@ RowPtrs<T> GetOrSetTempRowPtrs(
 template <class Func, typename... Args>
 decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                             Args&&... args) {
-#if GEMMA_ENABLE_NUQ
-  if (base->GetType() == Type::kNUQ) {
-    const MatPtrT<NuqStream> mat(*base);
-    return func(&mat, std::forward<Args>(args)...);
-  }
-#endif  // GEMMA_ENABLE_NUQ
+  if constexpr (GEMMA_ENABLE_NUQ) {
+    if (base->GetType() == Type::kNUQ) {
+      const MatPtrT<NuqStream> mat(*base);
+      return func(&mat, std::forward<Args>(args)...);
+    }
+  }
   if (base->GetType() == Type::kF32) {
     const MatPtrT<float> mat(*base);
@@ -376,13 +386,13 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                 const Func& func, Args&&... args) {
   HWY_DASSERT(base1->GetType() == base2->GetType());
-#if GEMMA_ENABLE_NUQ
-  if (base1->GetType() == Type::kNUQ) {
-    const MatPtrT<NuqStream> mat1(*base1);
-    const MatPtrT<NuqStream> mat2(*base2);
-    return func(&mat1, &mat2, std::forward<Args>(args)...);
-  }
-#endif  // GEMMA_ENABLE_NUQ
+  if constexpr (GEMMA_ENABLE_NUQ) {
+    if (base1->GetType() == Type::kNUQ) {
+      const MatPtrT<NuqStream> mat1(*base1);
+      const MatPtrT<NuqStream> mat2(*base2);
+      return func(&mat1, &mat2, std::forward<Args>(args)...);
+    }
+  }
   if (base1->GetType() == Type::kF32) {
     const MatPtrT<float> mat1(*base1);
@@ -508,5 +518,51 @@ class MatFactory {
   MatPadding padding_;
 };
+// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
+#pragma pack(push, 1)  // power of two size
+template <typename T>
+class StridedView {
+ public:
+  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0),
+        cols_(static_cast<uint32_t>(cols)),
+        stride_(static_cast<uint32_t>(stride)) {
+    HWY_DASSERT(stride >= cols);
+  }
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  StridedView(const MatPtrT<T>& mat, size_t r, size_t c, size_t cols)
+      : StridedView(const_cast<T*>(mat.Row(r)) + c, cols, mat.Stride()) {
+    HWY_DASSERT(c < mat.Cols());
+    HWY_DASSERT(cols <= mat.Cols() - c);
+  }
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  StridedView<T> View(size_t r, size_t c, size_t cols) const {
+    HWY_DASSERT(c < Cols());
+    HWY_DASSERT(cols <= Cols() - c);
+    return StridedView<T>(Row(r) + c, cols, stride_);
+  }
+  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
+  size_t Cols() const { return static_cast<size_t>(cols_); }
+  size_t Stride() const { return static_cast<size_t>(stride_); }
+  void SetStride(size_t stride) {
+    HWY_DASSERT(stride >= Cols());
+    stride_ = stride;
+  }
+ private:
+  T* HWY_RESTRICT row0_;
+  uint32_t cols_;
+  uint32_t stride_;
+};
+#pragma pack(pop)
+using StridedViewBF = StridedView<BF16>;
+using StridedViewD = StridedView<double>;
 }  // namespace gcpp
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
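Closing note (editorial): a hedged usage sketch of the relocated `StridedView`, using only the raw-pointer constructor and `View()` shown above; the buffer, sizes, and function name are illustrative, not from the commit.

// Wraps an existing row-major BF16 buffer whose rows are `stride` elements
// apart, then takes an 8-column window at (row 2, column 4). Preconditions
// (checked by HWY_DASSERT in the real class): stride >= cols and 4 + 8 <= cols.
void StridedViewExample(BF16* buffer, size_t cols, size_t stride) {
  StridedViewBF full(buffer, cols, stride);
  StridedViewBF window = full.View(/*r=*/2, /*c=*/4, /*cols=*/8);
  BF16* row = window.Row(0);  // == buffer + 2 * stride + 4
  (void)row;
}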