diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index bf7bd68..3a20690 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -148,21 +148,21 @@ class MMStoreHorizontalSumsIntoC {
     }
   }
 
-  // Scales the dot-product terms and adds bias (if present) and stores the
-  // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
-  // `MMSetC`, the vectors are written as-is (first call, or small K).
-  // Otherwise, they are partial sums and are accumulated into C.
-  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CRows>
-  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
-                        const size_t row_c, const size_t col_c,
-                        const MMArgs& args, CRows C_rows) const {
-    const V4 vscale = hn::Set(d4, args.scale);
+  // Scales the dot-product terms plus `add` (if non-null) and stores the four
+  // 4-wide vectors to `C` starting at row 0, column 0. If `tag` is `MMSetC`,
+  // the vectors are written as-is (first call, or small K). Otherwise, they
+  // are partial sums and are accumulated into C.
+  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CView>
+  HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3,
+                        const float scale, const float* HWY_RESTRICT add,
+                        const size_t imc, Tag tag, CView C_rows) const {
+    const V4 vscale = hn::Set(d4, scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
-    const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
-    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, C_rows, row_c, col_c);
-    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, C_rows, row_c, col_c);
+    const V4 vadd = hn::Load(d4, add ? add : kZero);
+    MaybeScaleAndStore<0>(d4, sum0, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<1>(d4, sum1, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<2>(d4, sum2, vscale, vadd, tag, imc, C_rows);
+    MaybeScaleAndStore<3>(d4, sum3, vscale, vadd, tag, imc, C_rows);
   }
 
  private:
@@ -199,13 +199,13 @@
   }
 
   template <size_t kRow, class DF4, class VF4 = hn::Vec<DF4>,
-            class Tag, typename TC>
+            class Tag, class CView>
   static HWY_INLINE void MaybeScaleAndStore(DF4 df4, VF4 sum, VF4 vscale,
-                                            VF4 vadd, Tag, RowPtrs<TC> C_rows,
-                                            const size_t row_c,
-                                            const size_t col_c) {
+                                            VF4 vadd, Tag, const size_t imc,
+                                            CView C_view) {
     if constexpr (kRow < kRowsAC) {
-      TC* HWY_RESTRICT pos = C_rows[row_c + kRow] + col_c;
+      using TC = hwy::RemoveCvRef<decltype(*C_view.Row(0))>;
+      TC* HWY_RESTRICT pos = C_view.Row(imc + kRow);
       const hn::Rebind<TC, DF4> dc4;
       if constexpr (hwy::IsSame<Tag, MMAddC>()) {
         vadd = F32FromTC(dc4, hn::Load(dc4, pos));  // load prior value
@@ -234,7 +234,7 @@ class MMDecompress {
 
     // Neither A nor B require padding because `LoopKC` handles remainders.
     if constexpr (hwy::IsSame<TB, BF16>()) {
-      return View(B, row_b, range_kc.begin(), range_kc.Num());
+      return StridedViewBF(B, row_b, range_kc.begin(), range_kc.Num());
     }
 
     const PackedSpan<const TB> B_span = B.PaddedSpan();
@@ -264,7 +264,7 @@
     if constexpr (IsBF16<TA>()) {
       // We can use a view, regardless of columns/padding, because
       // `MMKernel::LoopKC` supports non-vector multiples.
-      return View(A, 0, 0, A.Cols());
+      return StridedViewBF(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
      // support a separate F32 kernel; most A are already BF16. We also only
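Note: for reference, the net effect of `Store`/`MaybeScaleAndStore` on one `kRowsAC x 4` tile of C, written in scalar form. This is a sketch of the semantics as I read the patch; the helper and its parameters are hypothetical, and the real code additionally converts between F32 and BF16 C.

#include <cstddef>

// MMSetC: C = scale * sum + add (the bias is applied exactly once);
// MMAddC: C += scale * sum (the prior C value replaces the bias term).
void StoreTileRef(const float sum[4][4], float scale,
                  const float* add,  // may be null; pre-offset to column 0
                  bool set_c,        // true for MMSetC, false for MMAddC
                  size_t rows, float* c_rows[4]) {
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < 4; ++c) {
      const float base = set_c ? (add ? add[c] : 0.0f) : c_rows[r][c];
      c_rows[r][c] = scale * sum[r][c] + base;
    }
  }
}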
@@ -277,15 +277,6 @@
   }
 
  private:
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
   // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
   static HWY_NOINLINE void DecompressA(const MatPtrT<TA>& A,
                                        const StridedViewBF A_view,
@@ -402,26 +393,26 @@ class MMKernel {
       kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
 
  public:
-  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
-  // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
+  // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` is
+  // `mc x kc` and `B_view` is `(kNR x kc)`. All views, including `add`, start
+  // at row/col 0. `CView` is either `RowPtrs` or `StridedView`.
   // Called by B3A2C0 and by callers that hoist `A_view`.
-  template <class Tag, class CRows>
+  template <class Tag, class CView>
   static HWY_INLINE void A2C0(const StridedViewBF A_view,
                               const StridedViewBF B_view, size_t mr,
-                              const IndexRange& range_mc, const size_t row_b,
-                              size_t kc, Tag tag, const MMArgs& args,
-                              CRows C_rows) {
+                              const IndexRange& range_mc, size_t kc,
+                              const float scale, const float* HWY_RESTRICT add,
+                              Tag tag, CView C_view) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
-    const size_t row0 = range_mc.begin();
+    const size_t mc = range_mc.Num();
     size_t imc = 0;
 
     // M == 1, or x86 with 8 SIMD registers:
     if (HWY_UNLIKELY(mr == 1)) {
       for (; imc < mc; ++imc) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
       return;
     }
@@ -430,13 +421,11 @@
     if (HWY_UNLIKELY(mr == 2)) {
       if (HWY_LIKELY(mc >= 2)) {
         for (; imc <= mc - 2; imc += 2) {
-          LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                    C_rows);
+          LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
         }
       }
       if (HWY_UNLIKELY(imc != mc)) {
-        LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
       return;
     }
@@ -444,18 +433,17 @@
     HWY_DASSERT(mr == 4);
     if (HWY_LIKELY(mc >= 4)) {
       for (; imc <= mc - 4; imc += 4) {
-        LoopKC<4>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args,
-                  C_rows);
+        LoopKC<4>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       }
     }
     const size_t remainder_mc = mc - imc;
     HWY_DASSERT(remainder_mc < 4);
     if (HWY_UNLIKELY(remainder_mc & 2)) {
-      LoopKC<2>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+      LoopKC<2>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       imc += 2;
     }
     if (HWY_UNLIKELY(remainder_mc & 1)) {
-      LoopKC<1>(A_view, B_view, row0 + imc, imc, row_b, kc, tag, args, C_rows);
+      LoopKC<1>(A_view, B_view, imc, kc, scale, add, tag, C_view);
       imc += 1;
     }
     HWY_DASSERT(imc == mc);
   }
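The `mr` dispatch above only changed its argument plumbing, but its shape is easy to miss in diff form: `mc` rows of A are processed in strips of `mr` rows, with a 2-row and/or 1-row tail when `mr == 4` does not divide `mc`. A standalone sketch (the helper is hypothetical):

#include <cstddef>
#include <vector>

// Mirrors A2C0's loop structure for mr in {1, 2, 4}: full mr-row strips,
// then the remainder, e.g. mc=7, mr=4 -> {4, 2, 1}.
std::vector<size_t> StripHeights(size_t mc, size_t mr) {
  std::vector<size_t> strips;
  size_t imc = 0;
  for (; imc + mr <= mc; imc += mr) strips.push_back(mr);
  const size_t remainder = mc - imc;
  if (remainder & 2) strips.push_back(2);
  if (remainder & 1) strips.push_back(1);
  return strips;
}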
@@ -466,11 +454,11 @@
   // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
   // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
   // `ForeachKC` and when there is only a single KC task.
-  template <class Tag, class CRows>
+  template <class Tag, typename TC>
   static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
-                     const MMArgs& args, const IndexRange& range_mc,
-                     const IndexRange& range_kc, const IndexRange& range_nc,
-                     size_t mr, Tag out_tag, CRows C_rows) {
+                     const IndexRange& range_mc, const IndexRange& range_kc,
+                     const IndexRange& range_nc, const MMArgs& args,
+                     Tag out_tag, RowPtrs<TC> C) {
     HWY_ALIGN BF16 B_storage[B_storage_max];
 
     const size_t kc = range_kc.Num();
@@ -482,24 +470,28 @@
 
     for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
          row_b += kNR) {
-      StridedViewBF B_view =
+      const StridedViewBF B_view =
           MMDecompress::DecompressB(B, row_b, range_kc, B_storage_view);
-      A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
+      const RowPtrs<TC> C_view = C.View(range_mc.begin(), row_b);
+      const float* HWY_RESTRICT add = args.add ? args.add + row_b : nullptr;
+      A2C0(A_view, B_view, args.mr, range_mc, kc, args.scale, add, out_tag,
+           C_view);
     }
   }
 
-  template <class CRows>
+  template <typename TC>
   static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
-                        const MMArgs& args, const IndexRange& range_mc,
+                        const IndexRange& range_mc,
                         const IndexRangePartition& ranges_kc,
-                        const IndexRange& range_nc, size_t mr, CRows C_rows) {
+                        const IndexRange& range_nc, const MMArgs& args,
+                        RowPtrs<TC> C) {
     // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
     // by writing directly into it, and later accumulating into it.
     ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMSetC(), C);
     });
     ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
-      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows);
+      B3A2C0(A, B, range_mc, range_kc, range_nc, args, MMAddC(), C);
     });
   }
 
@@ -593,19 +585,20 @@
   // Innermost loop over `kc` columns (typically 1024-4096, not necessarily a
   // multiple of `NBF`) in steps of one vector, for `kRowsAC` rows of `A_view`
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
-  // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
-  // `A` and `B` are always BF16, `C` can be F32 or BF16.
-  template <size_t kRowsAC, class Tag, class CRows>
+  // Updates a `kRowsAC x kNR` tile in `C_view` starting at row `imc`,
+  // column 0. `A` and `B` are always BF16, `C` can be F32 or BF16. `add` is
+  // likewise already offset to the C column.
+  template <size_t kRowsAC, class Tag, class CView>
   static HWY_INLINE void LoopKC(const StridedViewBF A_view,
-                                const StridedViewBF B_view, size_t row_ac,
-                                size_t imc, size_t col_c, size_t kc, Tag tag,
-                                const MMArgs& args, CRows C_rows) {
+                                const StridedViewBF B_view, size_t imc,
+                                size_t kc, const float scale,
+                                const float* HWY_RESTRICT add, Tag tag,
+                                CView C_view) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
 
     HWY_DASSERT(kRowsAC <= kMaxMR);
-    HWY_DASSERT(col_c % kNR == 0);
 
     // Rows are aligned to `kMaxMR`, except for the last tile of A.
     // `kRowsAC` rows of A (null for the rest) and `kNR` rows of B.
@@ -784,7 +777,7 @@
     hn::Vec<decltype(df)> sum0, sum1, sum2, sum3;
     horz.Reduce4x4(df, C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22,
                    C23, C30, C31, C32, C33, sum0, sum1, sum2, sum3);
-    horz.Store(d4, sum0, sum1, sum2, sum3, tag, row_ac, col_c, args, C_rows);
+    horz.Store(d4, sum0, sum1, sum2, sum3, scale, add, imc, tag, C_view);
   }
 };
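`ForeachKC` peels the first KC range so that C never needs zero-initialization: the first `B3A2C0` call stores with `MMSetC`, all later calls accumulate with `MMAddC`, and `MaybeScaleAndStore` above selects the behavior at compile time. A minimal standalone model of that tag dispatch, assuming the real tags are empty structs:

#include <type_traits>

struct MMSetC {};
struct MMAddC {};

// Scalar stand-in for the vectorized store; the branch is resolved at
// compile time, so each instantiation contains only one path.
template <class Tag>
void StoreOrAccumulate(float* c, float partial, Tag) {
  if constexpr (std::is_same_v<Tag, MMSetC>) {
    *c = partial;   // first KC range: overwrite
  } else {
    *c += partial;  // remaining KC ranges: accumulate
  }
}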
@@ -884,7 +877,7 @@ class MMLoops {
   // or with the best config.
   template <typename TB, typename TC>
   static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
-                                    RowPtrs<TC> C_rows, const MMArgs& args) {
+                                    RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
     PROFILER_ZONE3(args.env.ctx.profiler,
                    args.env.ctx.Worker(args.options.cluster_idx), zone);
@@ -892,7 +885,7 @@
     DispatchParallelism(
         args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
           DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
-            Loop(order, parallel, A, B, C_rows, args);
+            Loop(order, parallel, A, B, C, args);
           });
         });
   }
@@ -904,11 +897,11 @@
     return HWY_MAX(kNR, line_bytes / sizeof_TC);
   }
 
-  // Single M and K ranges, parallel N. Fills all of C directly.
+  // Single M and K ranges, parallel N.
   template <class Parallel, typename TB, typename TC>
   static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
@@ -932,10 +925,21 @@
         for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
              row_b += kNR) {
-          StridedViewBF B_view =
+          const StridedViewBF B_view =
               MMDecompress::DecompressB(B, row_b, range_K, B_storage_view);
-          MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
-                         args, C_rows);
+          const RowPtrs<TC> C_view = C.View(range_M.begin(), row_b);
+          const float* HWY_RESTRICT add =
+              args.add ? args.add + row_b : nullptr;
+
+          MMKernel::A2C0(A_view, B_view, args.mr, range_M, K, args.scale, add,
+                         MMSetC(), C_view);
+        }
+
+        if constexpr (IsBF16<TC>()) {
+          if (args.options.fused) {
+            StridedViewBF C2(nullptr, 0, 0);
+            args.options.fused(C, range_M, range_nc, C2, worker);
+          }
         }
       });
   }
@@ -944,7 +948,7 @@
   template <class Parallel, typename TB, typename TC>
   static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
     HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
     const IndexRange& range_mc = args.ranges_mc.Range(0);
@@ -955,17 +959,24 @@
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc,
-                              range_nc, args.mr, C_rows);
+          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc,
+                              range_nc, args, C);
+
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
 
   // Parallel loops over mc/nc blocks of M/range_n, single K.
-  // Fills `mc x nc` sections of C directly, in parallel.
+  // Fills `mc x nc` sections of C.
   template <class Parallel, typename TB, typename TC>
   static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
     HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
     const IndexRange& range_K = args.ranges_kc.Range(0);
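The loops now hand the kernel a C view instead of absolute row/column indices: `C.View(range_M.begin(), row_b)` bakes the tile origin into the view, and `A2C0`/`LoopKC` index from (0, 0). A float-only sketch of the bookkeeping, matching the `RowPtrs` changes in the util/mat.h hunks further down (`RowPtrsRef` is hypothetical):

#include <cstddef>

struct RowPtrsRef {
  float** rows;  // base row pointers
  size_t r0 = 0, c0 = 0;
  // As in the patch, offsets are absolute with respect to the base row
  // pointers; View() does not compose with an existing view's offsets.
  RowPtrsRef View(size_t r, size_t c) const { return RowPtrsRef{rows, r, c}; }
  float* Row(size_t i) const { return rows[r0 + i] + c0; }
};
// After C.View(range_M.begin(), row_b), Row(imc) resolves to global row
// range_M.begin() + imc, starting at column row_b.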
@@ -976,17 +987,24 @@
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr,
-                           MMSetC(), C_rows);
+          MMKernel::B3A2C0(A, B, range_mc, range_K, range_nc, args, MMSetC(),
+                           C);
+
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
 
-  // Parallel loops over mc/nc blocks of M/range_np, sequential K.
+  // Parallel loops over mc/nc blocks of M/range_n, sequential K.
   // Accumulates into `mc x nc` sections of `C`.
   template <class Parallel, typename TB, typename TC>
   static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
                               const StridedViewBF A, const MatPtrT<TB>& B,
-                              RowPtrs<TC> C_rows, const MMArgs& args) {
+                              RowPtrs<TC> C, const MMArgs& args) {
     static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
 
     parallel.ForRangesMC_NC(
@@ -995,8 +1013,15 @@
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
           mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
-          MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc,
-                              args.mr, C_rows);
+          MMKernel::ForeachKC(A, B, range_mc, args.ranges_kc, range_nc, args,
+                              C);
+
+          if constexpr (IsBF16<TC>()) {
+            if (args.options.fused) {
+              StridedViewBF C2(nullptr, 0, 0);
+              args.options.fused(C, range_mc, range_nc, C2, worker);
+            }
+          }
         });
   }
 };  // MMLoops
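Each loop variant now invokes `args.options.fused`, when set and when `TC` is BF16, on the `mc x nc` tile it has just finished, presumably while that tile is still cache-resident. A sketch of registering the hook; the parameter list follows the `MMFused` alias added to ops/matmul.h below, but `WithFused` and the lambda body are hypothetical:

#include "ops/matmul.h"

gcpp::MMOptions WithFused() {
  gcpp::MMOptions options;
  options.fused = [](gcpp::RowPtrsBF C, const gcpp::IndexRange& range_mc,
                     const gcpp::IndexRange& range_nc, gcpp::StridedViewBF C2,
                     size_t worker) {
    // Post-process the just-written tile of C, e.g. apply an activation to
    // rows range_mc and columns range_nc. C2 is reserved for GatedMatmul.
  };
  return options;
}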
diff --git a/ops/matmul.h b/ops/matmul.h
index a85d192..93e7b04 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -60,54 +60,6 @@ HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;  // TODO: shrink?
 // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
 HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
 
-// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
-// Also used to decompress B, hence non-const.
-#pragma pack(push, 1)  // power of two size
-template <typename T>
-class StridedView {
- public:
-  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
-      : row0_(row0),
-        cols_(static_cast<uint32_t>(cols)),
-        stride_(static_cast<uint32_t>(stride)) {
-    HWY_DASSERT(stride >= cols);
-  }
-
-  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
-  size_t Cols() const { return static_cast<size_t>(cols_); }
-
-  size_t Stride() const { return static_cast<size_t>(stride_); }
-  void SetStride(size_t stride) {
-    HWY_DASSERT(stride >= Cols());
-    stride_ = stride;
-  }
-
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  StridedView View(size_t r, size_t c, size_t cols) const {
-    HWY_DASSERT(c < Cols());
-    HWY_DASSERT(cols <= Cols() - c);
-    return StridedView(Row(r) + c, cols, stride_);
-  }
-
- private:
-  T* HWY_RESTRICT row0_;
-  uint32_t cols_;
-  uint32_t stride_;
-};
-#pragma pack(pop)
-
-using StridedViewBF = StridedView<BF16>;
-using StridedViewD = StridedView<double>;
-
-using MMFused = std::function<void(RowPtrs<BF16>, const IndexRange&, const IndexRange&, StridedViewBF, size_t)>;
-
-struct MMOptions {
-  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
-  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
-
-  MMFused fused;
-};
-
 // Policy classes for parallelism, implementing some of `ParallelismStrategy`.
 
 struct MMParallelNone {
@@ -735,6 +687,19 @@ struct MatMulEnv {
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
 
+// Called with the entire C matrix, the sub-ranges of M (rows) and N (cols)
+// that this thread has just filled, a view into a second tile (only for the
+// upcoming `GatedMatmul`), and the worker thread index (see `ParallelFor`).
+using MMFused = std::function<void(RowPtrsBF, const IndexRange&,
+                                   const IndexRange&, StridedViewBF, size_t)>;
+
+struct MMOptions {
+  uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
+  ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+
+  MMFused fused;  // called if non-null and `TC` is BF16.
+};
+
 // Arguments to MatMul() that are independent of the A/B/C types. Reduces
 // register pressure compared to individual values/references. Also used for
 // passing through `DispatchOrder`.
diff --git a/util/mat.h b/util/mat.h
index c8a4617..4360b69 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -38,17 +38,27 @@ namespace gcpp {
 template <typename T>
 class RowPtrs {
  public:
-  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
+  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}
+
+  RowPtrs View(size_t r, size_t c) {
+    RowPtrs view(row_ptrs_);
+    view.r0_ = static_cast<uint32_t>(r);
+    view.c0_ = static_cast<uint32_t>(c);
+    return view;
+  }
 
   T* HWY_RESTRICT Row(size_t row_idx) const {
-    return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
+    return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_;
   }
-  T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); }
 
  private:
   uint8_t** row_ptrs_;
+  uint32_t r0_;
+  uint32_t c0_;
 };
 
+using RowPtrsBF = RowPtrs<BF16>;
+
 // Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
 // or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
 // to store heterogeneous tensor references in a vector.
@@ -349,12 +359,12 @@ RowPtrs<T> GetOrSetTempRowPtrs(
 template <class Func, typename... Args>
 decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                             Args&&... args) {
-#if GEMMA_ENABLE_NUQ
-  if (base->GetType() == Type::kNUQ) {
-    const MatPtrT<NuqStream> mat(*base);
-    return func(&mat, std::forward<Args>(args)...);
+  if constexpr (GEMMA_ENABLE_NUQ) {
+    if (base->GetType() == Type::kNUQ) {
+      const MatPtrT<NuqStream> mat(*base);
+      return func(&mat, std::forward<Args>(args)...);
+    }
   }
-#endif  // GEMMA_ENABLE_NUQ
 
   if (base->GetType() == Type::kF32) {
     const MatPtrT<float> mat(*base);
@@ -376,13 +386,13 @@
 decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                 const Func& func, Args&&... args) {
   HWY_DASSERT(base1->GetType() == base2->GetType());
-#if GEMMA_ENABLE_NUQ
-  if (base1->GetType() == Type::kNUQ) {
-    const MatPtrT<NuqStream> mat1(*base1);
-    const MatPtrT<NuqStream> mat2(*base2);
-    return func(&mat1, &mat2, std::forward<Args>(args)...);
+  if constexpr (GEMMA_ENABLE_NUQ) {
+    if (base1->GetType() == Type::kNUQ) {
+      const MatPtrT<NuqStream> mat1(*base1);
+      const MatPtrT<NuqStream> mat2(*base2);
+      return func(&mat1, &mat2, std::forward<Args>(args)...);
+    }
   }
-#endif  // GEMMA_ENABLE_NUQ
 
   if (base1->GetType() == Type::kF32) {
     const MatPtrT<float> mat1(*base1);
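One consequence of switching `CallUpcasted`/`CallUpcastedSame` from `#if` to `if constexpr`: the condition is now a constant expression, so `GEMMA_ENABLE_NUQ` must be defined (to 0 or 1) wherever this header is compiled, whereas the preprocessor treated an undefined macro as 0. A typical guard, assuming the build does not already guarantee a definition:

#ifndef GEMMA_ENABLE_NUQ
#define GEMMA_ENABLE_NUQ 0  // hypothetical default; the build may define it
#endif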
@@ -508,5 +518,51 @@ class MatFactory {
   MatPadding padding_;
 };
 
+// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
+#pragma pack(push, 1)  // power of two size
+template <typename T>
+class StridedView {
+ public:
+  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0),
+        cols_(static_cast<uint32_t>(cols)),
+        stride_(static_cast<uint32_t>(stride)) {
+    HWY_DASSERT(stride >= cols);
+  }
+
+  // View of the 2D subrange of `mat` whose top-left is `r, c` and width is
+  // `cols`.
+  StridedView(const MatPtrT<T>& mat, size_t r, size_t c, size_t cols)
+      : StridedView(const_cast<T*>(mat.Row(r)) + c, cols, mat.Stride()) {
+    HWY_DASSERT(c < mat.Cols());
+    HWY_DASSERT(cols <= mat.Cols() - c);
+  }
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  StridedView View(size_t r, size_t c, size_t cols) const {
+    HWY_DASSERT(c < Cols());
+    HWY_DASSERT(cols <= Cols() - c);
+    return StridedView(Row(r) + c, cols, stride_);
+  }
+
+  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
+  size_t Cols() const { return static_cast<size_t>(cols_); }
+
+  size_t Stride() const { return static_cast<size_t>(stride_); }
+  void SetStride(size_t stride) {
+    HWY_DASSERT(stride >= Cols());
+    stride_ = stride;
+  }
+
+ private:
+  T* HWY_RESTRICT row0_;
+  uint32_t cols_;
+  uint32_t stride_;
+};
+#pragma pack(pop)
+
+using StridedViewBF = StridedView<BF16>;
+using StridedViewD = StridedView<double>;
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
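With `StridedView` now defined in util/mat.h alongside `MatPtrT`, callers such as `DecompressB` above can construct views directly via the new converting constructor. A usage sketch (`ColumnSlice` is hypothetical):

#include "util/mat.h"

namespace gcpp {
// View of the `cols`-wide slice of `mat` starting at column `c0`. The
// stride, and thus any row padding, is inherited from `mat`.
StridedViewBF ColumnSlice(const MatPtrT<BF16>& mat, size_t c0, size_t cols) {
  return StridedViewBF(mat, /*r=*/0, c0, cols);
}
}  // namespace gcpp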