From 0023ff8770705caad4427c92caabf4a098e58952 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Sat, 31 May 2025 10:55:12 -0700
Subject: [PATCH] Add support for arbitrary output row pointers

Useful for writing directly to KV cache.

PiperOrigin-RevId: 765615147
---
 ops/matmul-inl.h | 151 ++++++++++++++++++++++++++++-------------------
 ops/matmul.h     |  16 +++--
 2 files changed, 102 insertions(+), 65 deletions(-)

diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index fc245d8..ddcddd5 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -80,6 +80,19 @@ hn::Vec<DC> TCFromF32(DC dc, hn::Vec<DF> vf) {
   return hn::DemoteTo(dc, vf);
 }
 
+template <typename TC>
+class CRows {
+ public:
+  CRows(uint8_t** C_rows) : C_rows_(C_rows) {}
+
+  TC* HWY_RESTRICT operator[](size_t row_idx) const {
+    return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]);
+  }
+
+ private:
+  uint8_t** C_rows_;
+};
+
 // Tag classes, passed to `MMKernel::A2C0` to choose between writing one
 // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the
 // first kc result to partial, or accumulating the next kc result into partial
@@ -110,7 +123,7 @@ class MMStoreHorizontalSumsIntoC {
                              VF C20, VF C21, VF C22, VF C23,  //
                              VF C30, VF C31, VF C32, VF C33,  //
                              const size_t row_c, const size_t col_c,
-                             const MMArgs& args, const RowPtr<TC>& C) const {
+                             const MMArgs& args, CRows<TC> C_rows) const {
     HWY_ALIGN float buf[16 * hn::MaxLanes(df)];
     const size_t N = hn::Lanes(df);
     // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
@@ -146,10 +159,10 @@ class MMStoreHorizontalSumsIntoC {
     if constexpr (kAdd) {
       vadd = hn::Load(d4, args.add + col_c);
     }
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C, row_c, col_c);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, C_rows, row_c, col_c);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, C_rows, row_c, col_c);
   }
 
  private:
@@ -185,13 +198,14 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 
-  template <size_t kRow, typename TC, class DF4, class VF4 = hn::Vec<DF4>>
+  template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>,
+            typename TC>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, const RowPtr<TC>& C,
+                                            VF4 vadd, CRows<TC> C_rows,
                                             const size_t row_c,
                                             const size_t col_c) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C.Row(row_c + kRow) + col_c;
+      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
       const hn::Rebind<TC, DF4> dc4;
       const VF4 out = hn::MulAdd(sum, vscale, vadd);
       hn::Store(TCFromF32(dc4, out), dc4, pos);
@@ -359,7 +373,7 @@ class MMKernel {
   static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view,
                               size_t mr, const IndexRange& range_mc,
                               const size_t row_b, size_t kc, Tag tag,
-                              const MMArgs& args, const RowPtr<TC>& C) {
+                              const MMArgs& args, CRows<TC> C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -368,7 +382,8 @@ class MMKernel {
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -377,11 +392,13 @@ class MMKernel {
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                    C_rows);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
       return;
     }
@@ -389,17 +406,18 @@ class MMKernel {
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
+                  C_rows);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C);
+      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
      imc += 1;
     }
     HWY_DASSERT(imc == mc);
@@ -496,11 +514,11 @@ class MMKernel {
   // with top-left corner `partial.Row(row_ac) + col_c`. Both A and B must be
   // BF16 so we can load directly without `Decompress2`, which is expensive for
   // NUQ and requires 2x unrolling, which requires more loads.
-  template <size_t kRowsAC, typename TC, class Tag>
+  template <size_t kRowsAC, class Tag, typename TC>
   static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view,
                                 size_t row_ac, size_t imc, size_t col_c,
                                 size_t kc, Tag tag, const MMArgs& args,
-                                const RowPtr<TC>& C) {
+                                CRows<TC> C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     const size_t NBF = hn::Lanes(dbf);
@@ -614,11 +632,11 @@ class MMKernel {
       if (args.add) {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/true>()(
             df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-            C31, C32, C33, row_ac, col_c, args, C);
+            C31, C32, C33, row_ac, col_c, args, C_rows);
       } else {
         MMStoreHorizontalSumsIntoC<kRowsAC, /*kAdd=*/false>()(
             df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30,
-            C31, C32, C33, row_ac, col_c, args, C);
+            C31, C32, C33, row_ac, col_c, args, C_rows);
       }
     } else {
       MMAddHorizontalSumsIntoPartial<kRowsAC>()(
@@ -642,27 +660,27 @@ class MMScaleDemoteAdd {
   template <typename TC>
   static HWY_INLINE void FillC(const IndexRange& range_mc,
                                const IndexRange& range_nc, const MMArgs& args,
-                               const RowPtr<TC>& C) {
+                               CRows<TC> C_rows) {
     size_t row_c = range_mc.begin();
     if (args.add) {
       constexpr bool kAdd = true;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     } else {
       constexpr bool kAdd = false;
       if (range_mc.Num() >= 4) {
         for (; row_c <= range_mc.end() - 4; row_c += 4) {
-          Do4Rows<kAdd>(row_c, range_nc, args, C);
+          Do4Rows<kAdd>(row_c, range_nc, args, C_rows);
         }
       }
       for (; row_c < range_mc.end(); ++row_c) {
-        Do1Row<kAdd>(row_c, range_nc, args, C);
+        Do1Row<kAdd>(row_c, range_nc, args, C_rows);
       }
     }
   }
@@ -671,7 +689,7 @@ class MMScaleDemoteAdd {
   // Unrolled for 4 rows to reduce the number of loads from `add`.
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc,
-                                 const MMArgs& args, const RowPtr<TC>& C) {
+                                 const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -685,10 +703,10 @@ class MMScaleDemoteAdd {
     const double* HWY_RESTRICT pr2 = args.partial.Row(row_c + 2);
     const double* HWY_RESTRICT pr3 = args.partial.Row(row_c + 3);
 
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
-    TC* HWY_RESTRICT cr1 = C.Row(row_c + 1);
-    TC* HWY_RESTRICT cr2 = C.Row(row_c + 2);
-    TC* HWY_RESTRICT cr3 = C.Row(row_c + 3);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
+    TC* HWY_RESTRICT cr1 = C_rows[row_c + 1];
+    TC* HWY_RESTRICT cr2 = C_rows[row_c + 2];
+    TC* HWY_RESTRICT cr3 = C_rows[row_c + 3];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -789,7 +807,7 @@ class MMScaleDemoteAdd {
   // Same as above but handles a single row (for remainder rows).
   template <bool kAdd, typename TC>
   static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc,
-                                const MMArgs& args, const RowPtr<TC>& C) {
+                                const MMArgs& args, CRows<TC> C_rows) {
     const hn::ScalableTag<double> dd;
     const hn::Rebind<float, decltype(dd)> df;  // result of DemoteTo
     const hn::Rebind<TC, decltype(dd)> dc;
@@ -798,7 +816,7 @@ class MMScaleDemoteAdd {
     const size_t ND = hn::Lanes(dd);
     const VD vscale = hn::Set(dd, args.scale);
     const double* HWY_RESTRICT pr0 = args.partial.Row(row_c + 0);
-    TC* HWY_RESTRICT cr0 = C.Row(row_c + 0);
+    TC* HWY_RESTRICT cr0 = C_rows[row_c + 0];
 
     // We manually unroll 2x for higher IPC in batch=1.
     size_t col_c = range_nc.begin();
@@ -867,8 +885,8 @@ class MMPerPackage {
         A_(args_.env->storage.A(pkg_idx, A.Extents())),
         range_np_(range_np),
         mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(A.Extents().rows)),
-        ranges_kc_(config.RangesOfKC(A.Extents().cols)),
+        ranges_mc_(config.RangesOfMC(A.Rows())),
+        ranges_kc_(config.RangesOfKC(A.Cols())),
         ranges_nc_(config.RangesOfNC(range_np)),
         order_(config.Order()),
         inner_tasks_(config.InnerTasks()),
@@ -882,17 +900,16 @@ class MMPerPackage {
   // B is decompressed several call layers lower, but not all member functions
   // depend on TB, so pass it as an argument instead of templating the class.
   template <typename TB>
-  HWY_NOINLINE void operator()(const MatPtrT<TB>& B,
-                               const RowPtr<TC>& C) const {
+  HWY_NOINLINE void operator()(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     switch (order_) {
       case MMOrder::kNT:
-        return DoNT(B, C);
+        return DoNT(B, C_rows);
       case MMOrder::kNT_K:
-        return DoNT_K(B, C);
+        return DoNT_K(B, C_rows);
       case MMOrder::kNT_MT:
-        return DoNT_MT(B, C);
+        return DoNT_MT(B, C_rows);
       case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(B, C);
+        return DoNT_MT_K(B, C_rows);
       default:
         HWY_UNREACHABLE;
     }
@@ -913,7 +930,7 @@ class MMPerPackage {
 
   // Single M and K, parallel N. Fills all of C directly.
   template <typename TB>
-  HWY_INLINE void DoNT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -940,7 +957,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -949,7 +966,7 @@ class MMPerPackage {
 
   // Single M, parallel N, sequential K. Fills all of partial.
   template <typename TB>
-  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_K", args_);
     HWY_DASSERT(ranges_mc_.NumTasks() == 1);
@@ -975,7 +992,7 @@ class MMPerPackage {
          DecompressB(B, row_b, range_kc, B_view);
        }
        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
      }
    };
 
@@ -997,13 +1014,13 @@ class MMPerPackage {
     MMZone fill_zone;
     if (out_ == MMOut::kCopy) {
       fill_zone.MaybeEnter("MM.NT_K.FillC", args_);
-      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C);
+      MMScaleDemoteAdd::FillC(range_mc, range_np_, args_, C_rows);
     } else if (out_ == MMOut::kParM) {
       fill_zone.MaybeEnter("MM.NT_K.FillC.ParM", args_);
       args_.env->parallel.ForRangeMC(
           range_mc, pkg_idx_, [&](size_t row_a) HWY_ATTR {
            MMScaleDemoteAdd::FillC(IndexRange(row_a, row_a + 1), range_np_,
-                                    args_, C);
+                                    args_, C_rows);
          });
     } else {
       HWY_UNREACHABLE;  // kDirect is only used with kNT.
@@ -1013,7 +1030,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
   template <typename TB>
-  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT", args_);
     HWY_DASSERT(ranges_kc_.NumTasks() == 1);
@@ -1039,7 +1056,7 @@ class MMPerPackage {
             DecompressB(B, row_b, range_K, B_view);
           }
           MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                         args_, C);
+                         args_, C_rows);
         }
       });
 
@@ -1049,7 +1066,7 @@ class MMPerPackage {
   // Parallel loops over mc/nc blocks of M/range_np, sequential K.
   // Fills `mc x nc` sections of `partial`, then `C`, in parallel.
   template <typename TB>
-  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, const RowPtr<TC>& C) const {
+  HWY_INLINE void DoNT_MT_K(const MatPtrT<TB>& B, CRows<TC> C_rows) const {
     MMZone zone;
     zone.MaybeEnter("MM.NT_MT_K", args_);
     const size_t kc_max = ranges_kc_.TaskSize();
@@ -1074,7 +1091,7 @@ class MMPerPackage {
          DecompressB(B, row_b, range_kc, B_view);
        }
        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C);
+                       C_rows);
      }
    };  // loop_nc
    args_.env->parallel.ForRangesMC_NC(
@@ -1097,7 +1114,7 @@ class MMPerPackage {
         HWY_DASSERT(out_ == MMOut::kCopy);
         MMZone fill_zone;
         fill_zone.MaybeEnter("MM.NT_MT_K.FillC", args_);
-        MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C);
+        MMScaleDemoteAdd::FillC(range_mc, range_nc, args_, C_rows);
       });
   }
 
@@ -1258,7 +1275,7 @@ struct MMImpl {
   // or with the best config.
   template <typename TA, typename TB, typename TC>
   static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                                    const RowPtr<TC>& C, const MMArgs& args,
+                                    CRows<TC> C_rows, const MMArgs& args,
                                     const MMConfig& config) {
     MMZone matmul_zone;
     matmul_zone.MaybeEnter("MM.DoMatMul", args);
@@ -1267,7 +1284,7 @@ struct MMImpl {
     args.env->parallel.ForPkg(
         args.per_key->ranges_np.NumTasks(), [&](size_t pkg_idx) {
           const IndexRange& range_np = args.per_key->ranges_np.Range(pkg_idx);
-          MMPerPackage<TC>(A, args, config, pkg_idx, range_np)(B, C);
+          MMPerPackage<TC>(A, args, config, pkg_idx, range_np)(B, C_rows);
        });
  }
};
@@ -1275,7 +1292,7 @@ struct MMImpl {
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
 // `A` is a row-major matrix with `M` rows and `B` is transposed. The latter's
-// `K = B.Extents().cols`, which must match `A.Extents().cols`, is the number
+// `K = B.Cols()`, which must match `A.Cols()`, is the number
 // of rows in the original B. `N = C.Cols()` must be a multiple of 4. There
 // are no other restrictions on shape, though performance is better when `M % 4
 // == 0` or `M <= 4`.
@@ -1295,11 +1312,11 @@ struct MMImpl {
 template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
-                              const RowPtr<TC>& C) {
+                              CRows<TC> C_rows) {
   const Allocator& allocator = env.ctx.allocator;
-  const size_t M = A.Extents().rows;
-  const size_t K = A.Extents().cols;
-  const size_t N = B.Extents().rows;
+  const size_t M = A.Rows();
+  const size_t K = A.Cols();
+  const size_t N = B.Rows();
   const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
   intptr_t index = MMImpl::IndexOfKey(key, env.keys);
   // First time we see this shape/key.
@@ -1323,7 +1340,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMArgs args(env, per_key, static_cast<float>(A.Scale()) * B.Scale(),
                     add, env.storage.Partial());
   if (HWY_LIKELY(tuner.Best())) {
-    MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
+    MMImpl::DoMatMul(A, B, C_rows, args, *tuner.Best());
     return &per_key;
   }
 
@@ -1332,8 +1349,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   // First call: enumerate all feasible configs.
   if (HWY_UNLIKELY(!tuner.HasCandidates())) {
     // Ensure matrix dimensions match each other.
-    HWY_ASSERT(K == B.Extents().cols);
-    HWY_ASSERT(N == C.Cols());
+    HWY_ASSERT(K == B.Cols());
     HWY_ASSERT(M <= MMStorage::kMaxM);
     HWY_ASSERT(K <= MMStorage::kMaxK);
     HWY_ASSERT(N <= MMStorage::kMaxN);
@@ -1347,7 +1363,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   const MMConfig& cfg = tuner.NextConfig();
 
   const uint64_t t0 = hwy::timer::Start();
-  MMImpl::DoMatMul(A, B, C, args, cfg);
+  MMImpl::DoMatMul(A, B, C_rows, args, cfg);
   const uint64_t t1 =
       env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
   const double min_elapsed = static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
@@ -1376,6 +1392,19 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
   return &per_key;
 }
 
+// Adapter that fills the row array. This is the common case, whereas only
+// GemmaAttention::ComputeQKV uses the arbitrary output rows feature.
+template <typename TA, typename TB, typename TC>
+HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+                              const float* HWY_RESTRICT add, MatMulEnv& env,
+                              const RowPtr<TC>& C) {
+  HWY_DASSERT(B.Rows() == C.Cols());
+  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
+    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(C.Row(row_ac));
+  }
+  return MatMul(A, B, add, env, CRows<TC>(&env.storage.OutRow(0)));
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
diff --git a/ops/matmul.h b/ops/matmul.h
index c2474cb..ee6037b 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -193,10 +193,11 @@ class MMStorage {
   // Internally threaded; must not be called concurrently with the same
   // `ThreadingContext` (used via `parallel`).
   MMStorage(const Allocator& allocator, MMParallel& parallel)
-      // Per-worker copies of `partial` would be wasteful. We instead allocate
-      // one instance of the maximum matrix extents because threads write at
-      // false-sharing-free granularity.
-      : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
+      : out_rows(hwy::AllocateAligned<uint8_t*>(kMaxM)),
+        // Per-worker copies of `partial` would be wasteful. We instead allocate
+        // one instance of the maximum matrix extents because threads write at
+        // false-sharing-free granularity.
+        partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
                          MatPadding::kOdd),
         // Same stride independent of the actual C.Cols() so we can pre-bind.
         partial_(partial_storage_.Row(0), kMaxN, partial_storage_.Stride()) {
@@ -220,6 +221,8 @@ class MMStorage {
     BindC(partial_storage_, parallel);
   }
 
+  uint8_t*& OutRow(size_t row_idx) { return out_rows[row_idx]; }
+
   // Returns per-package matrix view.
   RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const {
     HWY_DASSERT(extents.rows <= kMaxM);
@@ -231,6 +234,11 @@ class MMStorage {
   RowPtrD Partial() const { return partial_; }
 
  private:
+  // Enables arbitrary output rows. Most callers pass `RowPtr`, which assumes a
+  // constant stride, but GemmaAttention::ComputeQKV writes to differing KV
+  // positions per query / output row. `kMaxM` elements are too large for the
+  // stack, hence dynamic allocation.
+  hwy::AlignedFreeUniquePtr<uint8_t*[]> out_rows;
   std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
   MatStorageT<double> partial_storage_;
   RowPtrD partial_;
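
Usage note (not part of the patch): the sketch below illustrates how a caller might route each output row of a MatMul directly into non-contiguous destinations such as KV-cache slots, by filling `MMStorage::OutRow` and calling the new `CRows<TC>` overload. Only `MatPtrT`, `MatMulEnv`, `MMPerKey`, `CRows`, `OutRow` and `MatMul` come from this patch; the function name `MatMulToRows`, the `kv_rows` array, and the element types are hypothetical, and the code is assumed to live in a translation unit that already includes `ops/matmul-inl.h` under the usual per-target (HWY_NAMESPACE) machinery.

```cpp
// Sketch only, under the assumptions stated above. `kv_rows[i]` is a
// hypothetical destination pointer for output row i (e.g. a KV-cache slot);
// rows need not be contiguous and need not share a stride, unlike the
// RowPtr<TC> overload. Requires A.Rows() <= MMStorage::kMaxM.
template <typename TA, typename TB>
MMPerKey* MatMulToRows(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                       MatMulEnv& env, float** kv_rows) {
  // One destination pointer per output row of C, i.e. per row of A.
  for (size_t row_ac = 0; row_ac < A.Rows(); ++row_ac) {
    env.storage.OutRow(row_ac) = reinterpret_cast<uint8_t*>(kv_rows[row_ac]);
  }
  // TC = float here; the bias vector (`add`) is omitted for brevity.
  return MatMul(A, B, /*add=*/nullptr, env,
                CRows<float>(&env.storage.OutRow(0)));
}
```

This mirrors what the RowPtr-based adapter in the patch does for the common case: it simply fills `OutRow(i)` with `C.Row(i)` and forwards to the `CRows<TC>` overload, so existing callers keep working unchanged.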