diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 563be4c..806deda 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -80,20 +80,6 @@ hn::Vec TCFromF32(DC dc, hn::Vec vf) { return hn::DemoteTo(dc, vf); } -// Type-safe wrapper over uint8_t row pointers referenced by MatPtrT. -template -class CRows { - public: - CRows(uint8_t** C_rows) : C_rows_(C_rows) {} - - TC* HWY_RESTRICT operator[](size_t row_idx) const { - return HWY_RCAST_ALIGNED(TC*, C_rows_[row_idx]); - } - - private: - uint8_t** C_rows_; -}; - // Tag classes, passed to `MMKernel::A2C0` to choose between writing one // (all-K) result to C via `MMStoreHorizontalSumsIntoC`, or writing the // first kc result to partial, or accumulating the next kc result into partial @@ -124,7 +110,7 @@ class MMStoreHorizontalSumsIntoC { VF C20, VF C21, VF C22, VF C23, // VF C30, VF C31, VF C32, VF C33, // const size_t row_c, const size_t col_c, - const MMArgs& args, CRows C_rows) const { + const MMArgs& args, RowPtrs C_rows) const { HWY_ALIGN float buf[16 * hn::MaxLanes(df)]; const size_t N = hn::Lanes(df); // Horizontal reductions (`ReduceSum`) are rather expensive, entailing @@ -202,7 +188,7 @@ class MMStoreHorizontalSumsIntoC { template , typename TC> static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale, - VF4 vadd, CRows C_rows, + VF4 vadd, RowPtrs C_rows, const size_t row_c, const size_t col_c) { if constexpr (kRow < kRowsAC) { @@ -236,7 +222,7 @@ class MMAddHorizontalSumsIntoPartial { VF F20, VF F21, VF F22, VF F23, // VF F30, VF F31, VF F32, VF F33, // const size_t row_c, const size_t col_c, - const RowPtrD& partial) const { + const StridedViewD& partial) const { // We accumulate in 64-bit to avoid loss of precision. static_assert(HWY_HAVE_FLOAT64, "Disable Armv7 NEON: we require fp64"); @@ -342,7 +328,8 @@ class MMAddHorizontalSumsIntoPartial { } template > - static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, const RowPtrD& partial, + static HWY_INLINE void MaybeAddStore(D4 d4, V4 sum, + const StridedViewD& partial, const size_t row_c, const size_t col_c) { if constexpr (kRow < kRowsAC) { double* HWY_RESTRICT pos = partial.Row(row_c + kRow) + col_c; @@ -371,10 +358,11 @@ class MMKernel { // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. template - static HWY_INLINE void A2C0(const RowPtrBF& A_view, const RowPtrBF& B_view, - size_t mr, const IndexRange& range_mc, - const size_t row_b, size_t kc, Tag tag, - const MMArgs& args, CRows C_rows) { + static HWY_INLINE void A2C0(const StridedViewBF& A_view, + const StridedViewBF& B_view, size_t mr, + const IndexRange& range_mc, const size_t row_b, + size_t kc, Tag tag, const MMArgs& args, + RowPtrs C_rows) { HWY_DASSERT(1 <= mr && mr <= kMaxMR); const size_t row0 = range_mc.begin(); const size_t mc = range_mc.Num(); @@ -516,10 +504,10 @@ class MMKernel { // BF16 so we can load directly without `Decompress2`, which is expensive for // NUQ and requires 2x unrolling, which requires more loads. 
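Aside: the hunks above swap the kernel signatures from the local `CRows` wrapper to `RowPtrs`, and from `RowPtrBF`/`RowPtrD` to `StridedViewBF`/`StridedViewD`. Below is a minimal plain-C++ sketch (no Highway, no SIMD) of the two access patterns involved: `RowPtrs` resolves each row through a type-erased per-row pointer array, while `StridedView` computes `row0 + r * stride`. The reduced class bodies, `kRows`/`kCols`/`kStride`, and `main` are stand-ins for illustration, not the real implementations.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for the RowPtrs that replaces CRows: rows may live anywhere, and
// each access goes through one extra pointer load. (The real code uses
// HWY_RCAST_ALIGNED instead of reinterpret_cast.)
template <typename T>
class RowPtrs {
 public:
  explicit RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
  T* operator[](size_t row) const {
    return reinterpret_cast<T*>(row_ptrs_[row]);
  }

 private:
  uint8_t** row_ptrs_;
};

// Stand-in for StridedView (formerly RowPtr): one base pointer plus a fixed
// pitch between rows, so no per-row indirection.
template <typename T>
class StridedView {
 public:
  StridedView(T* row0, size_t cols, size_t stride)
      : row0_(row0), cols_(cols), stride_(stride) {}
  T* Row(size_t r) const { return row0_ + r * stride_; }
  size_t Cols() const { return cols_; }

 private:
  T* row0_;
  size_t cols_, stride_;
};

int main() {
  constexpr size_t kRows = 4, kCols = 8, kStride = 8;
  std::vector<float> c_storage(kRows * kStride, 0.0f);
  std::vector<uint8_t*> row_bytes(kRows);
  for (size_t r = 0; r < kRows; ++r) {
    row_bytes[r] = reinterpret_cast<uint8_t*>(c_storage.data() + r * kStride);
  }
  RowPtrs<float> C_rows(row_bytes.data());
  StridedView<float> view(c_storage.data(), kCols, kStride);
  // Both views resolve to the same storage here; a kernel that stores a
  // horizontal sum through C_rows is visible through the strided view.
  C_rows[1][2] = 42.0f;
  std::printf("%.1f\n", view.Row(1)[2]);  // prints 42.0
  return 0;
}
```

In the diff, the strided form is used for A, for decompressed B tiles, and for `partial`, all of which have a fixed pitch, while C is written through the attached row pointers.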
template - static HWY_INLINE void LoopKC(const RowPtrBF& A_view, const RowPtrBF& B_view, - size_t row_ac, size_t imc, size_t col_c, - size_t kc, Tag tag, const MMArgs& args, - CRows C_rows) { + static HWY_INLINE void LoopKC(const StridedViewBF& A_view, + const StridedViewBF& B_view, size_t row_ac, + size_t imc, size_t col_c, size_t kc, Tag tag, + const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dbf; using VBF = hn::Vec; const size_t NBF = hn::Lanes(dbf); @@ -661,7 +649,7 @@ class MMScaleDemoteAdd { template static HWY_INLINE void FillC(const IndexRange& range_mc, const IndexRange& range_nc, const MMArgs& args, - CRows C_rows) { + RowPtrs C_rows) { size_t row_c = range_mc.begin(); if (args.add) { constexpr bool kAdd = true; @@ -690,7 +678,7 @@ class MMScaleDemoteAdd { // Unrolled for 4 rows to reduce the number of loads from `add`. template static HWY_INLINE void Do4Rows(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, CRows C_rows) { + const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dd; const hn::Rebind df; // result of DemoteTo const hn::Rebind dc; @@ -808,7 +796,7 @@ class MMScaleDemoteAdd { // Same as above but handles a single row (for remainder rows). template static HWY_INLINE void Do1Row(size_t row_c, const IndexRange& range_nc, - const MMArgs& args, CRows C_rows) { + const MMArgs& args, RowPtrs C_rows) { const hn::ScalableTag dd; const hn::Rebind df; // result of DemoteTo const hn::Rebind dc; @@ -901,7 +889,7 @@ class MMPerPackage { // B is decompressed several call layers lower, but not all member functions // depend on TB, so pass it as an argument instead of templating the class. template - HWY_NOINLINE void operator()(const MatPtrT& B, CRows C_rows) const { + HWY_NOINLINE void operator()(const MatPtrT& B, RowPtrs C_rows) const { switch (order_) { case MMOrder::kNT: return DoNT(B, C_rows); @@ -931,7 +919,7 @@ class MMPerPackage { // Single M and K, parallel N. Fills all of C directly. template - HWY_INLINE void DoNT(const MatPtrT& B, CRows C_rows) const { + HWY_INLINE void DoNT(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; zone.MaybeEnter("MM.NT", args_); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -939,7 +927,7 @@ class MMPerPackage { const IndexRange& range_M = ranges_mc_.Range(0); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K); + const StridedViewBF& A_view = A_.View(range_M.begin(), 0, K); const size_t B_stride = Stride(MatPadding::kOdd, K, sizeof(BF16), line_bytes_); @@ -948,11 +936,12 @@ class MMPerPackage { range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, [&](const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_storage_view(B_storage, K, B_stride); + const StridedViewBF B_storage_view(B_storage, K, B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view); + StridedViewBF B_view = + DecompressB(B, row_b, range_K, B_storage_view); MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(), args_, C_rows); } @@ -963,7 +952,7 @@ class MMPerPackage { // Single M, parallel N, sequential K. Fills all of partial. 
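Aside: `FillC` above dispatches on a compile-time `kAdd` flag and then unrolls four rows at a time (`Do4Rows`), with `Do1Row` handling remainder rows, so each load of the bias row `add` is reused across four output rows. Here is a scalar sketch of that control structure, assuming simple `std::vector` matrices and hypothetical signatures; the real code converts the f64 partial sums to the C type with SIMD.

```cpp
#include <cstddef>
#include <vector>

using Matd = std::vector<std::vector<double>>;
using Matf = std::vector<std::vector<float>>;

// Four-row unrolled body: one load of add[col] is reused by four rows.
template <bool kAdd>
void Do4Rows(size_t row_c, size_t col_begin, size_t col_end, double scale,
             const float* add, const Matd& partial, Matf& C) {
  for (size_t col = col_begin; col < col_end; ++col) {
    const float add_val = kAdd ? add[col] : 0.0f;
    for (size_t r = 0; r < 4; ++r) {
      C[row_c + r][col] =
          static_cast<float>(scale * partial[row_c + r][col]) + add_val;
    }
  }
}

// Single-row path for the remainder rows.
template <bool kAdd>
void Do1Row(size_t row_c, size_t col_begin, size_t col_end, double scale,
            const float* add, const Matd& partial, Matf& C) {
  for (size_t col = col_begin; col < col_end; ++col) {
    const float add_val = kAdd ? add[col] : 0.0f;
    C[row_c][col] = static_cast<float>(scale * partial[row_c][col]) + add_val;
  }
}

template <bool kAdd>
void FillC(size_t row_begin, size_t row_end, size_t col_begin, size_t col_end,
           double scale, const float* add, const Matd& partial, Matf& C) {
  size_t row_c = row_begin;
  for (; row_c + 4 <= row_end; row_c += 4) {
    Do4Rows<kAdd>(row_c, col_begin, col_end, scale, add, partial, C);
  }
  for (; row_c < row_end; ++row_c) {
    Do1Row<kAdd>(row_c, col_begin, col_end, scale, add, partial, C);
  }
}

int main() {
  const size_t M = 6, N = 4;
  Matd partial(M, std::vector<double>(N, 1.0));
  Matf C(M, std::vector<float>(N, 0.0f));
  std::vector<float> add(N, 0.5f);
  FillC</*kAdd=*/true>(0, M, 0, N, /*scale=*/2.0, add.data(), partial, C);
  return C[5][3] == 2.5f ? 0 : 1;  // 2.0 * 1.0 + 0.5
}
```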
template - HWY_INLINE void DoNT_K(const MatPtrT& B, CRows C_rows) const { + HWY_INLINE void DoNT_K(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; zone.MaybeEnter("MM.NT_K", args_); HWY_DASSERT(ranges_mc_.NumTasks() == 1); @@ -976,14 +965,15 @@ class MMPerPackage { const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); - const RowPtrBF B_storage_view( + const StridedViewBF& A_view = + A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedViewBF B_storage_view( B_storage, kc, Stride(MatPadding::kOdd, kc, sizeof(BF16), line_bytes_)); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); + StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, C_rows); } @@ -1023,7 +1013,7 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, single K. // Fills `mc x nc` sections of C directly, in parallel. template - HWY_INLINE void DoNT_MT(const MatPtrT& B, CRows C_rows) const { + HWY_INLINE void DoNT_MT(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; zone.MaybeEnter("MM.NT_MT", args_); HWY_DASSERT(ranges_kc_.NumTasks() == 1); @@ -1037,13 +1027,14 @@ class MMPerPackage { args_.env->parallel.ForRangesMC_NC( ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { - const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K); + const StridedViewBF& A_view = A_.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_storage_view(B_storage, K, B_stride); + const StridedViewBF B_storage_view(B_storage, K, B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - RowPtrBF B_view = DecompressB(B, row_b, range_K, B_storage_view); + StridedViewBF B_view = + DecompressB(B, row_b, range_K, B_storage_view); MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(), args_, C_rows); } @@ -1055,7 +1046,7 @@ class MMPerPackage { // Parallel loops over mc/nc blocks of M/range_np, sequential K. // Fills `mc x nc` sections of `partial`, then `C`, in parallel. template - HWY_INLINE void DoNT_MT_K(const MatPtrT& B, CRows C_rows) const { + HWY_INLINE void DoNT_MT_K(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; zone.MaybeEnter("MM.NT_MT_K", args_); const size_t kc_max = ranges_kc_.TaskSize(); @@ -1065,17 +1056,18 @@ class MMPerPackage { // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. 
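Aside: the `out_tag` parameter threaded through `loop_nc` above is the tag-class mechanism described earlier in the file: `MMSetC` writes an all-K result straight to C, while the partial-sum tags make the first kc block overwrite `partial` and later kc blocks accumulate into it. A scalar sketch of that compile-time dispatch, with the hypothetical names `StoreOrAdd` and `slot` standing in for the kernel's stores:

```cpp
#include <type_traits>

// Empty tag types, mirroring the tags passed to A2C0: choose set vs. add at
// compile time, so the kc loop can skip zero-initializing `partial`.
struct MMSetPartial {};  // first kc block: overwrite
struct MMAddPartial {};  // subsequent kc blocks: accumulate

template <class Tag>
void StoreOrAdd(double sum, double& slot, Tag) {
  if constexpr (std::is_same_v<Tag, MMSetPartial>) {
    slot = sum;
  } else {
    slot += sum;
  }
}

int main() {
  double slot = -123.0;                   // stale value: Set makes zero-init unnecessary
  StoreOrAdd(3.0, slot, MMSetPartial());  // kc block 0
  StoreOrAdd(4.0, slot, MMAddPartial());  // kc block 1
  return slot == 7.0 ? 0 : 1;
}
```

In the diff, the first kc iteration is peeled off and run with the set tag for exactly this reason: `partial` never needs to be zero-initialized.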
- const auto loop_nc = [&](const RowPtrBF& B_storage_view, + const auto loop_nc = [&](const StridedViewBF& B_storage_view, const IndexRange& range_mc, const IndexRange& range_kc, const IndexRange& range_nc, auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); - const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); + const StridedViewBF& A_view = + A_.View(range_mc.begin(), range_kc.begin(), kc); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { - RowPtrBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); + StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view); MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_, C_rows); } @@ -1084,7 +1076,7 @@ class MMPerPackage { ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_storage_view(B_storage, kc_max, B_stride); + const StridedViewBF B_storage_view(B_storage, kc_max, B_stride); // Peel off the first iteration of the kc loop: avoid // zero-initializing `partial` by writing into it. @@ -1166,15 +1158,16 @@ class MMPerPackage { // Autotuning wrapper for `DoDecompressA`. template - HWY_INLINE RowPtrBF DecompressA(const MatPtrT& A) const { + HWY_INLINE StridedViewBF DecompressA(const MatPtrT& A) const { MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; // If already BF16, maybe return a view: if constexpr (hwy::IsSame()) { // Only if vector multiple and padded (see `DoDecompressA`). const size_t NBF = hn::Lanes(hn::ScalableTag()); if (HWY_LIKELY(A.Cols() % NBF == 0 && !A.IsPacked())) { - // Actually const, but RowPtr is also used for partial which is not. - return RowPtrBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); + // Const, but cast because StridedView is also used for `partial` which + // is non-const. + return StridedViewBF(const_cast(A.Row(0)), A.Cols(), A.Stride()); } } @@ -1212,12 +1205,12 @@ class MMPerPackage { // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL` // thanks to its large table lookups, and less so on other targets. template - HWY_INLINE RowPtrBF DecompressB(const MatPtrT& B, const size_t row_b, - const IndexRange& range_kc, - const RowPtrBF& B_view) const { + HWY_INLINE StridedViewBF DecompressB(const MatPtrT& B, const size_t row_b, + const IndexRange& range_kc, + const StridedViewBF& B_view) const { if constexpr (hwy::IsSame()) { - return RowPtrBF(const_cast(B.Row(row_b)) + range_kc.begin(), - range_kc.Num(), B.Stride()); + return StridedViewBF(const_cast(B.Row(row_b)) + range_kc.begin(), + range_kc.Num(), B.Stride()); } const hn::ScalableTag dbf; @@ -1242,7 +1235,7 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; - RowPtrBF A_; // view into A or pkg_A_, both of which are padded. + StridedViewBF A_; // view into A or pkg_A_, both of which are padded. const IndexRange range_np_; // From MMConfig: @@ -1272,7 +1265,7 @@ struct MMImpl { // or with the best config. 
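Aside: `DecompressA`/`DecompressB` above follow a view-or-copy pattern: when the operand is already BF16 (and, for A, vector-multiple and padded), they alias it with a `StridedView` instead of copying; otherwise they decompress rows into caller-provided scratch and return a view of that. A simplified standalone sketch, using `float` as the working type and a made-up `Compressed`/`Decompress1` in place of SFP/NUQ decompression:

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

template <typename T>
struct StridedView {  // aggregate stand-in for the real class
  T* row0;
  size_t cols;
  size_t stride;
  T* Row(size_t r) const { return row0 + r * stride; }
};

// Hypothetical compressed element; the real code decompresses SFP/NUQ.
struct Compressed { uint8_t bits; };
inline float Decompress1(Compressed c) { return static_cast<float>(c.bits) / 255.0f; }

template <typename TB>
StridedView<float> DecompressB(const TB* B, size_t rows, size_t cols,
                               size_t stride, StridedView<float> scratch) {
  if constexpr (std::is_same_v<TB, float>) {
    // Already the working type: return a view, no copy. (Const is cast away
    // because the same view type is also used for writable buffers.)
    return StridedView<float>{const_cast<float*>(B), cols, stride};
  } else {
    for (size_t r = 0; r < rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        scratch.Row(r)[c] = Decompress1(B[r * stride + c]);
      }
    }
    return scratch;
  }
}

int main() {
  constexpr size_t kRows = 4, kCols = 8, kStride = 8;
  std::vector<Compressed> B(kRows * kStride);
  std::vector<float> scratch_mem(kRows * kStride);
  StridedView<float> scratch{scratch_mem.data(), kCols, kStride};
  StridedView<float> view = DecompressB(B.data(), kRows, kCols, kStride, scratch);
  return view.Row(0) == scratch_mem.data() ? 0 : 1;  // copy path was taken
}
```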
template static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const MatPtrT& B, - CRows C_rows, const MMArgs& args, + RowPtrs C_rows, const MMArgs& args, const MMConfig& config) { MMZone matmul_zone; matmul_zone.MaybeEnter("MM.DoMatMul", args); @@ -1314,7 +1307,7 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C) { - CRows C_rows(C.GetRowPtrs()); + RowPtrs C_rows(C.GetRowPtrs()); if (HWY_UNLIKELY(!C.GetRowPtrs())) { if constexpr (HWY_IS_DEBUG_BUILD) { fprintf(stderr, @@ -1326,7 +1319,7 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, for (size_t r = 0; r < C.Rows(); ++r) { env.row_ptrs[0][r] = reinterpret_cast(C.Row(r)); } - C_rows = CRows(env.row_ptrs[0].get()); + C_rows = RowPtrs(env.row_ptrs[0].get()); } const Allocator& allocator = env.ctx.allocator; diff --git a/ops/matmul.h b/ops/matmul.h index 30ed703..bb35fa2 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -176,12 +176,12 @@ void BindB(MatPtr& B, size_t sizeof_TC, MMParallel& parallel); // C is BF16/float, or double for partial. void BindC(MatPtr& C, MMParallel& parallel); -// Lightweight view into `MatStorageT`. +// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows. #pragma pack(push, 1) // power of two size template -class RowPtr { +class StridedView { public: - RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) + StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride) : row0_(row0), cols_(static_cast(cols)), stride_(static_cast(stride)) { @@ -198,10 +198,10 @@ class RowPtr { } // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - RowPtr View(size_t r, size_t c, size_t cols) const { + StridedView View(size_t r, size_t c, size_t cols) const { HWY_DASSERT(c < Cols()); HWY_DASSERT(cols <= Cols() - c); - return RowPtr(Row(r) + c, cols, stride_); + return StridedView(Row(r) + c, cols, stride_); } private: @@ -211,8 +211,8 @@ class RowPtr { }; #pragma pack(pop) -using RowPtrBF = RowPtr; -using RowPtrD = RowPtr; +using StridedViewBF = StridedView; +using StridedViewD = StridedView; // Per-package storage for packed A, and one global C-shaped `partial` for // accumulating partial dot products (sections of K). @@ -260,19 +260,19 @@ class MMStorage { } // Returns per-package matrix view. - RowPtrBF A(size_t pkg_idx, const Extents2D& extents) const { + StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); - return RowPtrBF(const_cast(pkg_A_[pkg_idx]->Row(0)), extents.cols, - pkg_A_[pkg_idx]->Stride()); + return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), + extents.cols, pkg_A_[pkg_idx]->Stride()); } - RowPtrD Partial() const { return partial_; } + StridedViewD Partial() const { return partial_; } private: std::unique_ptr> pkg_A_[MMParallel::kMaxPackages]; MatStorageT partial_storage_; - RowPtrD partial_; + StridedViewD partial_; }; //------------------------------------------------------------------------------ @@ -673,7 +673,7 @@ struct MatMulEnv { // Reduces register pressure compared to individual values/references. 
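Aside: the `MatMul` entry point above prefers row pointers already attached to C (`C.GetRowPtrs()`), and only falls back to deriving them from the strided layout via `C.Row(r)` into `env.row_ptrs` when none are attached. A reduced sketch of that fallback, where `Mat`, `MakeCRows`, and the scratch vector are stand-ins for `MatPtrT`, the inline code in `MatMul`, and `env.row_ptrs[0]`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
class RowPtrs {
 public:
  explicit RowPtrs(uint8_t** p) : p_(p) {}
  T* operator[](size_t r) const { return reinterpret_cast<T*>(p_[r]); }

 private:
  uint8_t** p_;
};

struct Mat {  // just the two accessors the fallback needs
  float* storage;
  size_t rows;
  size_t stride;                 // elements between row starts
  uint8_t** attached = nullptr;  // set by AttachRowPtrs
  uint8_t** GetRowPtrs() const { return attached; }
  float* Row(size_t r) const { return storage + r * stride; }
};

RowPtrs<float> MakeCRows(Mat& C, std::vector<uint8_t*>& scratch) {
  if (C.GetRowPtrs() != nullptr) return RowPtrs<float>(C.GetRowPtrs());
  // Slow path (the real code warns in debug builds): derive the pointers from
  // the strided layout once per call.
  scratch.resize(C.rows);
  for (size_t r = 0; r < C.rows; ++r) {
    scratch[r] = reinterpret_cast<uint8_t*>(C.Row(r));
  }
  return RowPtrs<float>(scratch.data());
}

int main() {
  std::vector<float> storage(4 * 8);
  Mat C{storage.data(), 4, 8};
  std::vector<uint8_t*> scratch;
  RowPtrs<float> C_rows = MakeCRows(C, scratch);
  return C_rows[3] == C.Row(3) ? 0 : 1;
}
```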
struct MMArgs { MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale, - const float* HWY_RESTRICT add, const RowPtrD& partial) + const float* HWY_RESTRICT add, const StridedViewD& partial) : env(&env), per_key(&per_key), scale(scale), @@ -686,7 +686,7 @@ struct MMArgs { double scale; const float* HWY_RESTRICT add; // Same size as C, threads write at false-sharing-free granularity. - RowPtrD partial; + StridedViewD partial; }; // Wrapper over hwy::Zone that is only enabled when autotuning finished. diff --git a/util/mat.h b/util/mat.h index a61bedd..5fe07b6 100644 --- a/util/mat.h +++ b/util/mat.h @@ -33,16 +33,19 @@ namespace gcpp { -// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. -template -class CRows { +// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used +// for C, in future also for A. +template +class RowPtrs { public: - CRows(TC** C_rows) : C_rows_(C_rows) {} + RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {} - TC* HWY_RESTRICT operator[](size_t row_idx) const { return C_rows_[row_idx]; } + T* HWY_RESTRICT operator[](size_t row_idx) const { + return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]); + } private: - TC** C_rows_; + uint8_t** row_ptrs_; }; // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector @@ -88,7 +91,9 @@ class MatPtr : public IFields { bool HasPtr() const { return ptr_ != nullptr; } - // Caller has initialized Rows() pointers in row_ptrs[]. + // Caller has initialized Rows() pointers in row_ptrs[]. Note that this only + // changes `GetRowPtrs`, not `Row()`, because that would require branching + // and only a few call sites, in particular MatMul, use row pointers. void AttachRowPtrs(uint8_t** row_ptrs) { row_ptrs_ = row_ptrs; for (size_t r = 0; r < Rows(); ++r) { @@ -96,6 +101,8 @@ class MatPtr : public IFields { } } + // Called by Activations to allocate once, rather than have to fill row + // pointers in each call to MatMul. void AllocateAndAttachRowPtrs( std::vector>& row_ptrs) { if (!HasPtr()) return; @@ -107,6 +114,7 @@ class MatPtr : public IFields { AttachRowPtrs(ptrs); }; + // If non-null, this array should be used instead of `Row()`. uint8_t** GetRowPtrs() const { return row_ptrs_; } // A single row counts as packed because there is no padding between rows.
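Aside: the comments added to `AttachRowPtrs`/`AllocateAndAttachRowPtrs`/`GetRowPtrs` pin down the contract: attaching row pointers only changes what `GetRowPtrs()` returns; `Row()` keeps its strided arithmetic, and only opted-in callers such as MatMul use the indirection. A usage sketch of that contract with a reduced `MatPtrSketch` stand-in (the real class carries more metadata and type erasure):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

class MatPtrSketch {
 public:
  MatPtrSketch(uint8_t* ptr, size_t rows, size_t stride_bytes)
      : ptr_(ptr), rows_(rows), stride_bytes_(stride_bytes) {}

  size_t Rows() const { return rows_; }
  // Always strided arithmetic; unaffected by AttachRowPtrs.
  uint8_t* Row(size_t r) const { return ptr_ + r * stride_bytes_; }

  // Caller has initialized Rows() pointers in row_ptrs[].
  void AttachRowPtrs(uint8_t** row_ptrs) { row_ptrs_ = row_ptrs; }
  // If non-null, callers that opt in (e.g. MatMul) use this instead of Row().
  uint8_t** GetRowPtrs() const { return row_ptrs_; }

 private:
  uint8_t* ptr_;
  size_t rows_;
  size_t stride_bytes_;
  uint8_t** row_ptrs_ = nullptr;
};

int main() {
  std::vector<float> storage(4 * 8);
  MatPtrSketch C(reinterpret_cast<uint8_t*>(storage.data()), 4,
                 8 * sizeof(float));

  // Allocate once (as Activations does via AllocateAndAttachRowPtrs), fill
  // from the strided layout, then attach.
  auto ptrs = std::make_unique<uint8_t*[]>(C.Rows());
  for (size_t r = 0; r < C.Rows(); ++r) ptrs[r] = C.Row(r);
  C.AttachRowPtrs(ptrs.get());

  assert(C.GetRowPtrs()[2] == C.Row(2));  // Row() itself did not change
  return C.GetRowPtrs() != nullptr ? 0 : 1;
}
```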