Add support for arbitrary output row pointers

Useful for writing directly to KV cache.

PiperOrigin-RevId: 765615147
Jan Wassenberg 2025-05-31 10:55:12 -07:00 committed by Copybara-Service
parent 9c3e089b09
commit 0023ff8770
2 changed files with 102 additions and 65 deletions
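For context, a minimal usage sketch of the new interface (not part of this commit): a caller that already holds one destination pointer per output row, e.g. rows of a KV cache, fills `MMStorage::OutRow` and passes a `CRows<TC>` view. `MatMulIntoRows` and `dest_rows` are hypothetical names; only `OutRow`, `CRows` and the `CRows` overload of `MatMul` come from this change.

// Sketch only; assumes the CRows<TC> MatMul overload added in this commit.
template <typename TA, typename TB, typename TC>
MMPerKey* MatMulIntoRows(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                         const float* HWY_RESTRICT add, MatMulEnv& env,
                         TC* const* dest_rows) {  // one pointer per row of A
  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
    // Rows may point anywhere, e.g. into the KV cache; no common stride.
    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(dest_rows[row_ac]);
  }
  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
}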


@@ -80,6 +80,19 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
   return hn::DemoteTo(dc, vf);
 }
 
+template <typename TC>
+class CRows {
+ public:
+  CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
+
+  TC* HWY_RESTRICT operator[](size_t row_idx) const {
+    return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
+  }
+
+ private:
+  uint8_t** C_rows_;
+};
+
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
@@ -110,7 +123,7 @@ class MMStoreHorizontalSumsIntoC {
                              VF C20, VF C21, VF C22, VF C23,  //
                              VF C30, VF C31, VF C32, VF C33,  //
                              const size_t row_c, const size_t col_c,
-                             const MMArgs& args, const RowPtr<TC>& C) const {
+                             const MMArgs& args, CRows<TC> C_rows) const {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -146,10 +159,10 @@
     if constexpr (kAdd) {
       vadd = hn::Load(d4, args.add + col_c);
     }
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
   }
 
  private:
@@ -185,13 +198,14 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 
-  template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
+  template <size_t kRow, /*deduced:*/ class DF4, class VF4 = hn::Vec<DF4>,
+            typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, const RowPtr<TC>& C,
+                                            VF4 vadd, CRows<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
+      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
@@ -359,7 +373,7 @@ class MMKernel {
   static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
                               size_t mr, const IndexRange& range_mc,
                               const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args, const RowPtr<TC>& C) {
+                              const MMArgs& args, CRows<TC> C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -368,7 +382,8 @@
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
      for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -377,11 +392,13 @@
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -389,17 +406,18 @@
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -496,11 +514,11 @@ class MMKernel {
   // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
-  template <size_t kRowsAC, class Tag, typename TC>
+  template <size_t kRowsAC, /*deduced:*/ class Tag, typename TC>
   static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
                                 size_t row_ac, size_t imc, size_t col_c,
                                 size_t kc, Tag tag, const MMArgs& args,
-                                const RowPtr<TC>& C) {
+                                CRows<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -614,11 +632,11 @@ class MMKernel {
       if (args.add) {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
            df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-           C31, C32, C33, row_ac, col_c, args, C);
+           C31, C32, C33, row_ac, col_c, args, C_rows);
       } else {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
            df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-           C31, C32, C33, row_ac, col_c, args, C);
+           C31, C32, C33, row_ac, col_c, args, C_rows);
       }
     } else {
       MMAddHorizontalSumsIntoPartial<kRowsAC, Tag>()(
@@ -642,27 +660,27 @@ class MMScaleDemoteAdd {
   template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
                                const IndexRange& range_nc, const MMArgs& args,
-                               const RowPtr<TC>& C) {
+                               CRows<TC> C_rows) {
     size_t row_c = range_mc.begin();
     if (args.add) {
       constexpr bool kAdd = true;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     } else {
       constexpr bool kAdd = false;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     }
   }
@@ -671,7 +689,7 @@ class MMScaleDemoteAdd {
   // Unrolled for 4 rows to reduce the number of loads from `add`.
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args, const RowPtr<TC>& C) {
+                                 const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -685,10 +703,10 @@ class MMScaleDemoteAdd {
     const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
     const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
 
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
-    TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
-    TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
-    TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
+    TC* HWY_RESTRICT cr1 = C_rows[row_c + 1];
+    TC* HWY_RESTRICT cr2 = C_rows[row_c + 2];
+    TC* HWY_RESTRICT cr3 = C_rows[row_c + 3];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -789,7 +807,7 @@ class MMScaleDemoteAdd {
   // Same as above but handles a single row (for remainder rows).
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args, const RowPtr<TC>& C) {
+                                const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -798,7 +816,7 @@ class MMScaleDemoteAdd {
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
     const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -867,8 +885,8 @@ class MMPerPackage {
         A_(args_.env->storage.A(pkg_idx, A.Extents())),
         range_np_(range_np),
         mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.Extents().rows)),
-        ranges_kc_(config.RangesOfKC(A.Extents().cols)),
+        ranges_mc_(config.RangesOfMC(A.Rows())),
+        ranges_kc_(config.RangesOfKC(A.Cols())),
         ranges_nc_(config.RangesOfNC(range_np)),
         order_(config.Order()),
         inner_tasks_(config.InnerTasks()),
@@ -882,17 +900,16 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB, typename TC>
-  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
-                               const RowPtr<TC>& C) const {
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, C);
+        return DoNT(B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(B, C);
+        return DoNT_K(B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, C);
+        return DoNT_MT(B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, C);
+        return DoNT_MT_K(B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -913,7 +930,7 @@ class MMPerPackage {
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -940,7 +957,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_K, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                       args_, C);
+                       args_, C_rows);
       }
     });
@@ -949,7 +966,7 @@ class MMPerPackage {
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -975,7 +992,7 @@ class MMPerPackage {
         DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                     C);
+                     C_rows);
     }
   };
@@ -997,13 +1014,13 @@ class MMPerPackage {
     MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
       fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
-      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
+      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
       fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
       args_.env->parallel.ForRangeMC(
           range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
             MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
-                                    args_, C);
+                                    args_, C_rows);
           });
     } else {
       HWY_UNREACHABLE;  // kDirect is only used with kNT.
@@ -1013,7 +1030,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1039,7 +1056,7 @@ class MMPerPackage {
           DecompressB(B, row_b, range_K, B_view);
         }
         MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                       args_, C);
+                       args_, C_rows);
       }
     });
@@ -1049,7 +1066,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1074,7 +1091,7 @@ class MMPerPackage {
         DecompressB(B, row_b, range_kc, B_view);
       }
       MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                     C);
+                     C_rows);
     }
   };  // loop_nc
   args_.env->parallel.ForRangesMC_NC(
@@ -1097,7 +1114,7 @@ class MMPerPackage {
           HWY_DASSERT(out_ == MMOut::kCopy);
           MMZone fill_zone;
           fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
-          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
+          MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
         });
   }
@@ -1258,7 +1275,7 @@ struct MMImpl {
   // or with the best config.
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    const RowPtr<TC>& C, const MMArgs& args,
+                                    CRows<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1267,7 +1284,7 @@ struct MMImpl {
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
          const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C);
+          MMPerPackage(A, args, config, pkg_idx, range_np)(B, C_rows);
        });
   }
 };
@@ -1275,7 +1292,7 @@ struct MMImpl {
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
 // `A` is a row-major matrix with `M` rows and `B` is transposed. The latter's
-// `K = B.Extents().cols`, which must match `A.Extents().cols`, is the number
+// `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
 // == 0` or `M <= 4`.
@@ -1295,11 +1312,11 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
+                              CRows<TC> C_rows) {
   const Allocator& allocator = env.ctx.allocator;
-  const size_t M = A.Extents().rows;
-  const size_t K = A.Extents().cols;
-  const size_t N = B.Extents().rows;
+  const size_t M = A.Rows();
+  const size_t K = A.Cols();
+  const size_t N = B.Rows();
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
   intptr_t index = MMImpl::IndexOfKey(key, env.keys);
   // First time we see this shape/key.
@@ -1323,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
                     add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
     return &per_key;
   }
@@ -1332,8 +1349,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
     // Ensure matrix dimensions match each other.
-    HWY_ASSERT(K == B.Extents().cols);
-    HWY_ASSERT(N == C.Cols());
+    HWY_ASSERT(K == B.Cols());
     HWY_ASSERT(M <= MMStorage::kMaxM);
     HWY_ASSERT(K <= MMStorage::kMaxK);
     HWY_ASSERT(N <= MMStorage::kMaxN);
@@ -1347,7 +1363,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMConfig& cfg = tuner.NextConfig();
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, C, args, cfg);
+  MMImpl::DoMatMul(A, B, C_rows, args, cfg);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -1376,6 +1392,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   return &per_key;
 }
 
+// Adapter that fills the row array. This is the common case, whereas only
+// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
+template <typename TA, typename TB, typename TC>
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+                              const float* HWY_RESTRICT add, MatMulEnv& env,
+                              const RowPtr<TC>& C) {
+  HWY_DASSERT(B.Rows() == C.Cols());
+  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
+    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
+  }
+  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp


@@ -193,10 +193,11 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(const Allocator& allocator, MMParallel& parallel)
+      : out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)),
        // Per-worker copies of `partial` would be wasteful. We instead allocate
        // one instance of the maximum matrix extents because threads write at
        // false-sharing-free granularity.
-      : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
+        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
        // Same stride independent of the actual C.Cols() so we can pre-bind.
        partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@@ -220,6 +221,8 @@ class MMStorage {
     BindC(partial_storage_, parallel);
   }
 
+  uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
+
   // Returns per-package matrix view.
   RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
@@ -231,6 +234,11 @@ class MMStorage {
   RowPtrD Partial() const { return partial_; }
 
  private:
+  // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
+  // constant stride, but GemmaAttention::ComputeQKV writes to differing KV
+  // positions per query / output row. `kMaxM` elements are too large for the
+  // stack, hence dynamic allocation.
+  hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
   RowPtrD partial_;