From 461a9c7d1b18a269d505769e934af4dc68eaae4d Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 9 Sep 2025 07:13:03 -0700
Subject: [PATCH] Matmul refactoring towards fusion

MMLoops: move dispatch code out, use overloads.
Split the build target into matmul_env (for MatMulEnv/MMOptions).
weights: no longer call BindB.
Fix a potential out-of-bounds access in gemma_batch_bench.

PiperOrigin-RevId: 804895985
---
 .github/workflows/build.yml           |   1 +
 BUILD.bazel                           |  41 +-
 evals/gemma_batch_bench.cc            |   3 +-
 examples/simplified_gemma/BUILD.bazel |   2 +-
 gemma/gemma_args.h                    |   3 +-
 gemma/weights.cc                      |   2 -
 ops/matmul-inl.h                      | 541 +++++++++++---------
 ops/matmul.h                          | 180 ++++++---
 ops/matvec-inl.h                      |   1 -
 util/mat.h                            |   3 +-
 10 files changed, 388 insertions(+), 389 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2052a82..1512548 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -46,6 +46,7 @@ jobs:
             -D CMAKE_BUILD_TYPE=${{ matrix.build_type }}
             -D CMAKE_C_COMPILER_LAUNCHER=ccache
             -D CMAKE_CXX_COMPILER_LAUNCHER=ccache
+            -DCMAKE_POLICY_VERSION_MINIMUM=3.5
       - name: Build
         run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4

diff --git a/BUILD.bazel b/BUILD.bazel
index 52c2df3..dbe52b7 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -238,7 +238,6 @@ cc_library(
         ":configs",
         ":gemma_args",
         ":mat",
-        ":matmul",
         ":model_store",
         ":tensor_info",
         ":threading_context",
@@ -271,14 +270,33 @@ test_suite(
 )
 
 cc_library(
-    name = "matmul",
+    name = "matmul_env",
     srcs = ["ops/matmul.cc"],
     hdrs = ["ops/matmul.h"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":configs",
+        ":mat",
+        ":threading",
+        ":threading_context",
+        "@highway//:bit_set",
+        "@highway//:hwy",
+        "@highway//:nanobenchmark",
+        "@highway//:profiler",
+    ],
+)
+
+cc_library(
+    name = "matmul",
+    # Allow depending only on this target, without also matmul_env.
+    hdrs = ["ops/matmul.h"],
     textual_hdrs = ["ops/matmul-inl.h"],
     deps = [
         ":allocator",
         ":basics",
         ":mat",
+        ":matmul_env",
         ":threading",
         ":threading_context",
         "//compression:compress",
@@ -310,6 +328,7 @@ cc_library(
         ":basics",
         ":mat",
         ":matmul",
+        ":matmul_env",
         ":threading_context",
         "//compression:compress",
         "//compression:types",
@@ -333,11 +352,12 @@ cc_library(
         ":allocator",
         ":basics",
         ":mat",
-        ":matmul",
+        ":matmul_env",  # MMOptions
         ":matmul_static",
         ":threading_context",
         "//compression:compress",
         "@highway//:algo",
+        "@highway//:bit_set",
         "@highway//:hwy",
         "@highway//:math",
         "@highway//:matvec",
@@ -434,7 +454,7 @@ cc_test(
     deps = [
         ":basics",
         ":mat",
-        ":matmul",
+        ":matmul_env",
         ":matmul_static",
         ":ops",
         ":threading_context",
@@ -462,7 +482,8 @@ cc_test(
     ],
     deps = [
         ":basics",
-        ":matmul",
+        ":matmul_env",
+        ":matmul_static",
         ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
@@ -495,7 +516,6 @@ cc_library(
         ":args",
         ":basics",
         ":mat",
-        ":matmul",
        "//io",
         "@highway//:hwy",
         "@highway//:profiler",
@@ -523,13 +543,12 @@ cc_library(
         "gemma/gemma-inl.h",
     ],
     deps = [
-        ":allocator",
         ":basics",
         ":configs",
         ":gemma_args",
         ":kv_cache",
         ":mat",
-        ":matmul",
+        ":matmul_env",
         ":model_store",
         ":ops",
         ":threading",
@@ -569,7 +588,7 @@ cc_library(
         ":cross_entropy",
         ":gemma_args",
         ":gemma_lib",
-        ":matmul",
+        ":matmul_env",
         ":ops",
         ":threading_context",
         ":tokenizer",
@@ -600,7 +619,7 @@ cc_library(
         ":gemma_args",
         ":gemma_lib",
         ":kv_cache",
-        ":matmul",
+        ":matmul_env",
         ":threading",
         ":threading_context",
         ":tokenizer",
@@ -661,7 +680,7 @@ cc_binary(
         ":benchmark_helper",
         ":gemma_args",
         ":gemma_lib",
-        ":matmul",
+        ":matmul_env",
         ":tokenizer",
         "//compression:types",
         "//paligemma:image",

diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 6d97c61..135c2bb 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -93,7 +93,8 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
     if (qpos == questions.size()) qpos = 0;
   }
   std::vector<std::string> responses = BatchGemmaReply(inputs);
-  for (size_t i = 0; i < hwy::Unpredictable1() * 3; ++i) {
+  for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
+       ++i) {
     fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
   }

diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel
index 740ec7d..811906f 100644
--- a/examples/simplified_gemma/BUILD.bazel
+++ b/examples/simplified_gemma/BUILD.bazel
@@ -15,7 +15,7 @@ cc_library(
     deps = [
         "//:gemma_args",
         "//:gemma_lib",
-        "//:matmul",
+        "//:matmul_env",
         "//:threading_context",
         "//:tokenizer",
         "@highway//:hwy",

diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index b2d19ff..3135f50 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -24,8 +24,7 @@
 #include <string>
 #include <vector>
 
-#include "io/io.h"       // Path
-#include "ops/matmul.h"  // MMStorage::kMax*
+#include "io/io.h"  // Path
 #include "util/args.h"
 #include "util/basics.h"  // Tristate
 #include "util/mat.h"

diff --git a/gemma/weights.cc b/gemma/weights.cc
index b71e6b7..425a752 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -30,7 +30,6 @@
 #include "gemma/gemma_args.h"
 #include "gemma/model_store.h"
 #include "io/blob_store.h"
-#include "ops/matmul.h"  // MMParallel
 #include "util/mat.h"
 #include "util/threading_context.h"
 #include "hwy/base.h"
@@ -338,7 +337,6 @@ static void AllocateAndBindAll(std::vector& tensors,
         owners[start + task].AllocateFor(*tensor.mat, ctx.allocator,
                                          tensor.padding);
-        BindB(ctx, *tensor.mat, tensor.mat->ElementBytes());
       });
 }

diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index f2e9c49..21ac14b 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -152,10 +152,10 @@ class MMStoreHorizontalSumsIntoC {
   // four 4-wide vectors to `C` starting at `(row_c, col_c)`. If `tag` is
   // `MMSetC`, the vectors are written as-is (first call, or small K).
   // Otherwise, they are partial sums and are accumulated into C.
-  template <class D4, class V4 = hn::Vec<D4>, class Tag, typename TC>
+  template <class D4, class V4 = hn::Vec<D4>, class Tag, class CRows>
   HWY_INLINE void Store(D4 d4, V4 sum0, V4 sum1, V4 sum2, V4 sum3, Tag tag,
                         const size_t row_c, const size_t col_c,
-                        const MMArgs& args, RowPtrs<TC> C_rows) const {
+                        const MMArgs& args, CRows C_rows) const {
     const V4 vscale = hn::Set(d4, args.scale);
     HWY_ALIGN static constexpr float kZero[4] = {};
     const V4 vadd = hn::Load(d4, args.add ? args.add + col_c : kZero);
@@ -219,18 +219,24 @@ class MMStoreHorizontalSumsIntoC {
   }
 };  // MMStoreHorizontalSumsIntoC
 
-// Stateless, wraps member functions.
+// Stateless, wraps member functions. Contains the innermost 2-4 loops.
 class MMKernel {
+  // Compute size of per-worker storage for `kNR` row ranges of B. Stack
+  // allocation avoids passing a worker index.
+  static constexpr size_t B_stride_max =
+      kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
+
  public:
   // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
   // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
   // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
-  template <class Tag, typename TC>
+  // Called by B3A2C0 and by callers that hoist `A_view`.
+  template <class Tag, class CRows>
   static HWY_INLINE void A2C0(const StridedViewBF A_view,
                               const StridedViewBF B_view, size_t mr,
                               const IndexRange& range_mc, const size_t row_b,
                               size_t kc, Tag tag, const MMArgs& args,
-                              RowPtrs<TC> C_rows) {
+                              CRows C_rows) {
     HWY_DASSERT(1 <= mr && mr <= kMaxMR);
     const size_t row0 = range_mc.begin();
     const size_t mc = range_mc.Num();
@@ -280,6 +286,90 @@ class MMKernel {
     HWY_DASSERT(imc == mc);
   }
 
+  static constexpr size_t B_storage_max = kNR * B_stride_max;
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  template <typename T>
+  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
+                             size_t cols) {
+    HWY_DASSERT(c < AB.Cols());
+    HWY_DASSERT(cols <= AB.Cols() - c);
+    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
+  }
+
+  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
+  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
+  // thanks to its large table lookups, and less so on other targets.
+  template <typename TB>
+  static StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
+                                   const IndexRange& range_kc,
+                                   const StridedViewBF B_view) {
+    const hn::ScalableTag<BF16> dbf;
+    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
+
+    // Neither A nor B require padding because `LoopKC` handles remainders.
+    if constexpr (hwy::IsSame<TB, BF16>()) {
+      return View(B, row_b, range_kc.begin(), range_kc.Num());
+    }
+
+    const PackedSpan<const TB> B_span = B.PaddedSpan();
+
+    const size_t kc = range_kc.Num();
+    const size_t col0 = range_kc.begin();
+
+    for (size_t r = 0; r < kNR; ++r) {
+      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
+      BF16* HWY_RESTRICT to = B_view.Row(r);
+      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
+      // Verify that we zero-padded.
+      if constexpr (HWY_IS_DEBUG_BUILD) {
+        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
+          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
+        }
+      }
+    }
+    return B_view;
+  }
+
+  // Loop over NC/MC/KC, called from the outer loops. The MOMMS B3A2C0 reads
+  // `mc x kc` of A, `nc x kc` of B, and updates `mc x nc` of C. Called by
+  // `ForeachKC` and when there is only a single KC task.
+  template <typename TB, class Tag, class CRows>
+  static void B3A2C0(const StridedViewBF A, const MatPtrT<TB>& B,
+                     const MMArgs& args, const IndexRange& range_mc,
+                     const IndexRange& range_kc, const IndexRange& range_nc,
+                     size_t mr, Tag out_tag, CRows C_rows) {
+    HWY_ALIGN BF16 B_storage[B_storage_max];
+
+    const size_t kc = range_kc.Num();
+    const StridedViewBF A_view =
+        A.View(range_mc.begin(), range_kc.begin(), kc);
+
+    const size_t B_stride =
+        Stride(MatPadding::kOdd, kc, sizeof(BF16), args.line_bytes);
+    const StridedViewBF B_storage_view(B_storage, kc, B_stride);
+
+    for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
+         row_b += kNR) {
+      StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
+      A2C0(A_view, B_view, mr, range_mc, row_b, kc, out_tag, args, C_rows);
+    }
+  }
+
+  template <typename TB, class CRows>
+  static void ForeachKC(const StridedViewBF A, const MatPtrT<TB>& B,
+                        const MMArgs& args, const IndexRange& range_mc,
+                        const IndexRangePartition& ranges_kc,
+                        const IndexRange& range_nc, size_t mr, CRows C_rows) {
+    // Peel off the first iteration of the kc loop: avoid zero-initializing `C`
+    // by writing directly into it, and later accumulating into it.
+    ranges_kc.VisitFirst([&](const IndexRange& range_kc) {
+      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMSetC(), C_rows);
+    });
+    ranges_kc.VisitRemaining([&](const IndexRange& range_kc) {
+      B3A2C0(A, B, args, range_mc, range_kc, range_nc, mr, MMAddC(), C_rows);
+    });
+  }
+
  private:
   // Element-wise multiplies a vector from one row of A with `kNR` vectors,
   // each from a row of transposed B, and adds them to `kNR` fp32 `Cc`
   //
@@ -372,11 +462,11 @@ class MMKernel {
   // from range_mc-relative `imc` and `B_view` from row 0 (both at column 0).
   // Updates a `kRowsAC x kNR` tile with top-left `C.Row(row_ac) + col_c`.
   // `A` and `B` are always BF16, `C` can be F32 or BF16.
-  template <size_t kRowsAC, class Tag, typename TC>
+  template <size_t kRowsAC, class Tag, class CRows>
   static HWY_INLINE void LoopKC(const StridedViewBF A_view,
                                 const StridedViewBF B_view, size_t row_ac,
                                 size_t imc, size_t col_c, size_t kc, Tag tag,
-                                const MMArgs& args, RowPtrs<TC> C_rows) {
+                                const MMArgs& args, CRows C_rows) {
     const hn::ScalableTag<BF16> dbf;
     using VBF = hn::Vec<decltype(dbf)>;
     HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
@@ -601,7 +691,7 @@ class MMImpl {
                                    size_t vector_bytes,
                                    MatMulEnv::PerCluster& per_cluster) {
     const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
-    intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys);
+    intptr_t index = IndexOfKey(key, per_cluster.keys);
     // First time we see this shape/key.
     if (HWY_UNLIKELY(index < 0)) {
       per_cluster.keys.Append(key, vector_bytes);
@@ -614,9 +704,9 @@ class MMImpl {
     return per_cluster.per_key[index];
   }
 
-  static void NotifyAutotuneResult(size_t M, size_t K, size_t N, double t0,
-                                   const MMConfig& cfg, MatMulEnv& env,
-                                   MMAutoTune<MMConfig>& tuner) {
+  static void NotifyAutotuneResult(MatMulEnv& env, size_t M, size_t K, size_t N,
+                                   double t0, MMAutoTune<MMConfig>& tuner,
+                                   const MMConfig& cfg) {
     const uint64_t t1 =
         env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const double min_elapsed =
         static_cast<double>(tuner.NotifyTicks(t1 - t0)) /
         hwy::platform::InvariantTicksPerSecond();
@@ -653,39 +743,16 @@ class MMImpl {
     }
   }
 
-  static size_t Worker(const MMArgs& args) {
-    return args.options.cluster_idx *
-           args.env->ctx.pools.MaxWorkersPerCluster();
-  }
-
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  template <typename T>
-  static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
-                             size_t cols) {
-    HWY_DASSERT(c < AB.Cols());
-    HWY_DASSERT(cols <= AB.Cols() - c);
-    return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
-  }
-
-  template <class Func>
-  static void DispatchParallelism(ParallelismStrategy parallelism,
-                                  const Func& func) {
-    switch (parallelism) {
-      case ParallelismStrategy::kHierarchical:
-        return func(MMParallelHierarchical());
-      case ParallelismStrategy::kNone:
-        return func(MMParallelNone());
-      case ParallelismStrategy::kWithinCluster:
-        return func(MMParallelWithinCluster());
-      default:
-        HWY_UNREACHABLE;
-    }
+  static size_t Worker(const MatMulEnv& env, size_t cluster_idx) {
+    return cluster_idx * env.ctx.pools.MaxWorkersPerCluster();
   }
 
   // Decompresses all `M x K` from `A` into padded BF16 `A_view`.
   static HWY_NOINLINE void DoDecompressA(const MatPtrT<float>& A,
                                          const StridedViewBF A_view,
-                                         MMParA par_a, const MMArgs& args) {
+                                         MMAutoTune<MMParA>& autotune,
+                                         MMParA par_a, const MatMulEnv& env,
+                                         const MMOptions& options) {
     const IndexRange all_M(0, A.Rows());
     const IndexRange all_K(0, A.Cols());
     HWY_DASSERT(all_K.Num() == A_view.Cols());
@@ -693,13 +760,13 @@ class MMImpl {
     const hn::ScalableTag<BF16> dbf;
     const size_t NBF = hn::Lanes(dbf);
 
-    static const auto zone = args.env->ctx.profiler.AddZone("MM.DecompressA");
+    static const auto zone = env.ctx.profiler.AddZone("MM.DecompressA");
 
     const auto do_range = [&](const IndexRange& range_M,
                               const IndexRange& range_K,
                               size_t worker) HWY_ATTR {
       MMZone mm_zone;
-      mm_zone.MaybeEnter(worker, zone, args);
+      mm_zone.MaybeEnter(worker, zone, env, &autotune);
 
       const size_t col0 = range_K.begin();
       const size_t cols = range_K.Num();
@@ -722,7 +789,7 @@ class MMImpl {
     switch (par_a) {
       case MMParA::kNone:
-        do_range(all_M, all_K, MMImpl::Worker(args));
+        do_range(all_M, all_K, Worker(env, options.cluster_idx));
         break;
 
       case MMParA::kK1:
@@ -732,27 +799,26 @@ class MMImpl {
        // At least one vector, otherwise DecompressAndZeroPad will add
        // padding, which might overwrite neighboring tasks. Also a whole cache
        // line to avoid false sharing.
-        const size_t multiple_K = HWY_MAX(NBF, args.line_bytes / sizeof(BF16));
+        const size_t multiple_K =
+            HWY_MAX(NBF, env.ctx.cache_info.LineBytes() / sizeof(BF16));
 
-        DispatchParallelism(
-            args.options.parallelism, [&](const auto& parallel) {
-              parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks,
-                            args.options.cluster_idx,
-                            [&](const IndexRange& range_K, size_t worker) {
-                              do_range(all_M, range_K, worker);
-                            });
-            });
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForN(env.ctx, all_K, multiple_K, inner_tasks,
+                        options.cluster_idx,
+                        [&](const IndexRange& range_K, size_t worker) {
+                          do_range(all_M, range_K, worker);
+                        });
+        });
         break;
       }
 
       case MMParA::kM:
-        DispatchParallelism(
-            args.options.parallelism, [&](const auto& parallel) {
-              parallel.ForRangeMC(
-                  args.env->ctx, all_M, args.options.cluster_idx,
-                  [&](size_t row_a, size_t worker) {
-                    do_range(IndexRange(row_a, row_a + 1), all_K, worker);
-                  });
-            });
+        DispatchParallelism(options.parallelism, [&](const auto& parallel) {
+          parallel.ForRangeMC(env.ctx, all_M, options.cluster_idx,
+                              [&](size_t row_a, size_t worker) {
+                                do_range(IndexRange(row_a, row_a + 1), all_K,
+                                         worker);
+                              });
+        });
         break;
     }
   }
 
   // Autotuning wrapper for `DoDecompressA`.
   static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
                                      const StridedViewBF A_view,
-                                     const MMArgs& args) {
-    MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a;
-
+                                     MMAutoTune<MMParA>& autotune,
+                                     const MatMulEnv& env,
+                                     const MMOptions& options) {
     if (HWY_LIKELY(autotune.Best())) {
-      return DoDecompressA(A, A_view, *autotune.Best(), args);
+      return DoDecompressA(A, A_view, autotune, *autotune.Best(), env, options);
     }
 
     // First call: generate candidates.
@@ -777,11 +843,11 @@ class MMImpl {
     const MMParA& par_a = autotune.NextConfig();
     const uint64_t t0 = hwy::timer::Start();
-    DoDecompressA(A, A_view, par_a, args);
+    DoDecompressA(A, A_view, autotune, par_a, env, options);
     const uint64_t t1 =
-        args.env->have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
+        env.have_timer_stop ? hwy::timer::Stop() : hwy::timer::Start();
     const uint64_t min_elapsed = autotune.NotifyTicks(t1 - t0);
-    if (HWY_UNLIKELY(args.env->print_measurement && autotune.ShouldPrint())) {
+    if (HWY_UNLIKELY(env.print_measurement && autotune.ShouldPrint())) {
       fprintf(stderr, "%s,%7.3f\n", StringFromParA(par_a),
               static_cast<double>(min_elapsed) /
                   hwy::platform::InvariantTicksPerSecond() * 1E6);
@@ -790,299 +856,148 @@ class MMImpl {
 
   template <typename TA>
   static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
-                                                   const MMArgs& args) {
+                                                   MMAutoTune<MMParA>& autotune,
+                                                   const MatMulEnv& env,
+                                                   MMOptions options) {
     if constexpr (IsBF16<TA>()) {
       // We can use a view, regardless of columns/padding, because `LoopKC`
       // supports non-vector multiples.
-      return View(A, 0, 0, A.Cols());
+      return MMKernel::View(A, 0, 0, A.Cols());
     } else {
       // Always decompress. To reduce code size/compile time, we no longer
       // support a separate F32 kernel; most A are already BF16. We also only
       // have a single MMStorage.
-      HWY_ASSERT(args.options.cluster_idx == 0);
-      const StridedViewBF A_view = args.env->storage.A(A.Extents());
-      DecompressA(A, A_view, args);
+      HWY_ASSERT(options.cluster_idx == 0);
+      const StridedViewBF A_view = env.storage.A(A.Extents());
+      DecompressA(A, A_view, autotune, env, options);
       return A_view;
     }
   }
 };
 
-// Contains several variants of the outer M/N/K loops, and calls `A2C0` which
-// loops over the inner KC and MC. Member variables avoid long argument lists.
-class MMState {
+// Defines several variants of the outer M/N/K loops (see `MMOrder`).
+class MMLoops {
  public:
-  MMState(size_t M, size_t K, size_t N, const MMArgs& args,
-          const MMConfig& config)
-      : args_(args),
-        range_n_(0, N),
-        mr_(config.MR()),
-        ranges_mc_(config.RangesOfMC(M)),
-        ranges_kc_(config.RangesOfKC(K)),
-        ranges_nc_(config.RangesOfNC(N)),
-        order_(config.Order()),
-        inner_tasks_(config.InnerTasks()) {}
-
   // Called from `MatMul` in two places: either with the next autotune config,
   // or with the best config.
   template <typename TB, typename TC>
-  HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
-                                        const MatPtrT<TB>& B,
-                                        RowPtrs<TC> C_rows) const {
-    static const auto zone =
-        args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
-    PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
+  static HWY_NOINLINE void Dispatch(const StridedViewBF A, const MatPtrT<TB>& B,
+                                    RowPtrs<TC> C_rows, const MMArgs& args) {
+    static const auto zone = args.env.ctx.profiler.AddZone("MM.Dispatch");
+    PROFILER_ZONE3(args.env.ctx.profiler,
+                   MMImpl::Worker(args.env, args.options.cluster_idx), zone);
 
-    MMImpl::DispatchParallelism(
-        args_.options.parallelism,
-        [&](const auto& parallel) { DispatchOrder(parallel, A, B, C_rows); });
+    DispatchParallelism(
+        args.options.parallelism, [&](const auto& parallel) HWY_ATTR {
+          DispatchOrder(args.order, [&](const auto& order) HWY_ATTR {
+            Loop(order, parallel, A, B, C_rows, args);
+          });
+        });
   }
 
  private:
-  // Compute size of per-worker storage for `kNR` row ranges of B. Stack
-  // allocation avoids passing a worker index.
-  static constexpr size_t B_stride_max_ =
-      kMaxKC + 2 * CacheInfo::MaxLineBytes() / sizeof(BF16);
-  static constexpr size_t B_storage_max_ = kNR * B_stride_max_;
-
   // Granularity of `ForN`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.
-  size_t MultipleN(size_t sizeof_TC) const {
-    return HWY_MAX(kNR, args_.line_bytes / sizeof_TC);
-  }
-
-  // B is decompressed several call layers lower, but not all member functions
-  // depend on `TB`, so pass it as an argument instead of templating the class.
-  template <class ParallelT, typename TB, typename TC>
-  HWY_NOINLINE void DispatchOrder(const ParallelT& parallel_policy,
-                                  const StridedViewBF A, const MatPtrT<TB>& B,
-                                  RowPtrs<TC> C_rows) const {
-    switch (order_) {
-      case MMOrder::kNT:
-        return DoNT(parallel_policy, A, B, C_rows);
-      case MMOrder::kNT_K:
-        return DoNT_K(parallel_policy, A, B, C_rows);
-      case MMOrder::kNT_MT:
-        return DoNT_MT(parallel_policy, A, B, C_rows);
-      case MMOrder::kNT_MT_K:
-        return DoNT_MT_K(parallel_policy, A, B, C_rows);
-      default:
-        HWY_UNREACHABLE;
-    }
+  static size_t MultipleN(size_t sizeof_TC, size_t line_bytes) {
+    return HWY_MAX(kNR, line_bytes / sizeof_TC);
   }
 
   // Single M and K ranges, parallel N. Fills all of C directly.
-  template <class ParallelT, typename TB, typename TC>
-  HWY_INLINE void DoNT(ParallelT parallel, const StridedViewBF A,
-                       const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT");
-    HWY_DASSERT(ranges_mc_.NumTasks() == 1);
-    HWY_DASSERT(ranges_kc_.NumTasks() == 1);
-    const IndexRange& range_M = ranges_mc_.Range(0);
-    const IndexRange& range_K = ranges_kc_.Range(0);
+  template <class Parallel, typename TB, typename TC>
+  static HWY_INLINE void Loop(MMOrderNT, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT<TB>& B,
+                              RowPtrs<TC> C_rows, const MMArgs& args) {
+    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT");
+    HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
+    HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
+    const IndexRange& range_M = args.ranges_mc.Range(0);
+    const IndexRange& range_K = args.ranges_kc.Range(0);
     const size_t K = range_K.Num();
     const StridedViewBF A_view = A.View(range_M.begin(), 0, K);
     const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
+        Stride(MatPadding::kOdd, K, sizeof(BF16), args.line_bytes);
 
-    // Similar to `loop_nc` below, but here we hoisted `A_view`.
+    // Similar to `B3A2C0`, but here we hoisted `A_view`.
     parallel.ForN(
-        args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
-        args_.options.cluster_idx,
+        args.env.ctx, args.range_n, MultipleN(sizeof(TC), args.line_bytes),
+        args.inner_tasks, args.options.cluster_idx,
         [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
           MMZone mm_zone;
-          mm_zone.MaybeEnter(worker, zone, args_);
+          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
 
-          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
+          HWY_ALIGN BF16 B_storage[MMKernel::B_storage_max];  // TLS
           const StridedViewBF B_storage_view(B_storage, K, B_stride);
 
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {
             StridedViewBF B_view =
-                DecompressB(B, row_b, range_K, B_storage_view);
-            MMKernel::A2C0(A_view, B_view, mr_, range_M, row_b, K, MMSetC(),
-                           args_, C_rows);
+                MMKernel::DecompressB(B, row_b, range_K, B_storage_view);
+            MMKernel::A2C0(A_view, B_view, args.mr, range_M, row_b, K, MMSetC(),
+                           args, C_rows);
           }
         });
   }
 
   // Single M range, parallel N, sequential K. Sets C, then accumulates.
-  template <class ParallelT, typename TB, typename TC>
-  HWY_INLINE void DoNT_K(ParallelT parallel, const StridedViewBF A,
-                         const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_K");
-    HWY_DASSERT(ranges_mc_.NumTasks() == 1);
-    const IndexRange& range_mc = ranges_mc_.Range(0);
+  template <class Parallel, typename TB, typename TC>
+  static HWY_INLINE void Loop(MMOrderNT_K, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT<TB>& B,
+                              RowPtrs<TC> C_rows, const MMArgs& args) {
+    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_K");
+    HWY_DASSERT(args.ranges_mc.NumTasks() == 1);
+    const IndexRange& range_mc = args.ranges_mc.Range(0);
 
-    // Loop over NC/MC/KC, called from the outer loops over K/N.
-    // C++14 generic lambda enables hoisting branches via template
-    // argument, while also capturing to avoid long argument lists.
-    const auto loop_nc = [&](BF16* B_storage, const IndexRange& range_kc,
-                             const IndexRange& range_nc,
-                             auto out_tag) HWY_ATTR {
-      const size_t kc = range_kc.Num();
-      const StridedViewBF A_view =
-          A.View(range_mc.begin(), range_kc.begin(), kc);
-      const StridedViewBF B_storage_view(
-          B_storage, kc,
-          Stride(MatPadding::kOdd, kc, sizeof(BF16), args_.line_bytes));
-
-      for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
-           row_b += kNR) {
-        StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C_rows);
-      }
-    };
-
-    parallel.ForN(
-        args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
-        args_.options.cluster_idx,
-        [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
-          MMZone mm_zone;
-          mm_zone.MaybeEnter(worker, zone, args_);
-
-          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-
-          // Peel off the first iteration of the kc loop: avoid
-          // zero-initializing `partial` by writing into it.
-          ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
-            loop_nc(B_storage, range_kc, range_nc, MMSetC());
-          });
-          ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
-            loop_nc(B_storage, range_kc, range_nc, MMAddC());
-          });
-        });
+    parallel.ForN(args.env.ctx, args.range_n,
+                  MultipleN(sizeof(TC), args.line_bytes), args.inner_tasks,
+                  args.options.cluster_idx,
+                  [&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
+                    MMZone mm_zone;
+                    mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
+                    MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc,
+                                        range_nc, args.mr, C_rows);
+                  });
   }
 
   // Parallel loops over mc/nc blocks of M/range_n, single K.
   // Fills `mc x nc` sections of C directly, in parallel.
-  template <class ParallelT, typename TB, typename TC>
-  HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A,
-                          const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT");
-    HWY_DASSERT(ranges_kc_.NumTasks() == 1);
-    const IndexRange& range_K = ranges_kc_.Range(0);
-    const size_t K = range_K.Num();
-    const size_t B_stride =
-        Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
+  template <class Parallel, typename TB, typename TC>
+  static HWY_INLINE void Loop(MMOrderNT_MT, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT<TB>& B,
+                              RowPtrs<TC> C_rows, const MMArgs& args) {
+    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT");
+    HWY_DASSERT(args.ranges_kc.NumTasks() == 1);
+    const IndexRange& range_K = args.ranges_kc.Range(0);
 
-    // Similar to `loop_nc` below except for the profiler zone and `MMSetC`.
     parallel.ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
+        args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
-          mm_zone.MaybeEnter(worker, zone, args_);
-
-          const StridedViewBF A_view = A.View(range_mc.begin(), 0, K);
-          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const StridedViewBF B_storage_view(B_storage, K, B_stride);
-
-          for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
-               row_b += kNR) {
-            const StridedViewBF B_view =
-                DecompressB(B, row_b, range_K, B_storage_view);
-            MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, K, MMSetC(),
-                           args_, C_rows);
-          }
+          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
+          MMKernel::B3A2C0(A, B, args, range_mc, range_K, range_nc, args.mr,
+                           MMSetC(), C_rows);
        });
   }
 
   // Parallel loops over mc/nc blocks of M/range_n, sequential K.
   // Accumulates into `mc x nc` sections of `C`.
-  template <class ParallelT, typename TB, typename TC>
-  HWY_INLINE void DoNT_MT_K(ParallelT parallel, const StridedViewBF A,
-                            const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
-    static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
-    const size_t kc_max = ranges_kc_.TaskSize();
-    HWY_DASSERT(kc_max <= kMaxKC);
-    const size_t B_stride =
-        Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
-
-    // Sequential loop over NC/MC/KC, for when the M/N loops are
-    // already parallel. This is B3A2C0 in MOMMS terminology: we read
-    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `C`.
-    const auto loop_nc = [&](const StridedViewBF B_storage_view,
-                             const IndexRange& range_mc,
-                             const IndexRange& range_kc,
-                             const IndexRange& range_nc,
-                             auto out_tag) HWY_ATTR {
-      const size_t kc = range_kc.Num();
-      const StridedViewBF A_view =
-          A.View(range_mc.begin(), range_kc.begin(), kc);
-
-      for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
-           row_b += kNR) {
-        StridedViewBF B_view = DecompressB(B, row_b, range_kc, B_storage_view);
-        MMKernel::A2C0(A_view, B_view, mr_, range_mc, row_b, kc, out_tag, args_,
-                       C_rows);
-      }
-    };  // loop_nc
+  template <class Parallel, typename TB, typename TC>
+  static HWY_INLINE void Loop(MMOrderNT_MT_K, Parallel parallel,
+                              const StridedViewBF A, const MatPtrT<TB>& B,
+                              RowPtrs<TC> C_rows, const MMArgs& args) {
+    static const auto zone = args.env.ctx.profiler.AddZone("MM.NT_MT_K");
 
     parallel.ForRangesMC_NC(
-        args_.env->ctx, ranges_mc_, ranges_nc_, args_.options.cluster_idx,
+        args.env.ctx, args.ranges_mc, args.ranges_nc, args.options.cluster_idx,
         [&](const IndexRange& range_mc, const IndexRange& range_nc,
             size_t worker) HWY_ATTR {
           MMZone mm_zone;
-          mm_zone.MaybeEnter(worker, zone, args_);
-
-          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-          const StridedViewBF B_storage_view(B_storage, kc_max, B_stride);
-
-          // Peel off the first iteration of the kc loop: avoid
-          // zero-initializing `C` by writing into it.
-          ranges_kc_.VisitFirst([&](const IndexRange& range_kc) {
-            loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMSetC());
-          });
-          ranges_kc_.VisitRemaining([&](const IndexRange& range_kc) {
-            loop_nc(B_storage_view, range_mc, range_kc, range_nc, MMAddC());
-          });
+          mm_zone.MaybeEnter(worker, zone, args.env, &args.autotune);
+          MMKernel::ForeachKC(A, B, args, range_mc, args.ranges_kc, range_nc,
+                              args.mr, C_rows);
         });
   }
-
-  // Decompresses `kNR x kc` from `B[row_b, range_kc.begin()]` to row 0,
-  // col 0 of `B_view`. Decompressing SFP is relatively cheap on `AVX3_DL`
-  // thanks to its large table lookups, and less so on other targets.
-  template <typename TB>
-  HWY_INLINE StridedViewBF DecompressB(const MatPtrT<TB>& B, const size_t row_b,
-                                       const IndexRange& range_kc,
-                                       const StridedViewBF B_view) const {
-    const hn::ScalableTag<BF16> dbf;
-    HWY_LANES_CONSTEXPR const size_t NBF = hn::Lanes(dbf);
-
-    // Neither A nor B require padding because `LoopKC` handles remainders.
-    if constexpr (hwy::IsSame<TB, BF16>()) {
-      return MMImpl::View(B, row_b, range_kc.begin(), range_kc.Num());
-    }
-
-    const PackedSpan<const TB> B_span = B.PaddedSpan();
-
-    const size_t kc = range_kc.Num();
-    const size_t col0 = range_kc.begin();
-
-    for (size_t r = 0; r < kNR; ++r) {
-      const size_t packed_ofs = (row_b + r) * B.Stride() + col0;
-      BF16* HWY_RESTRICT to = B_view.Row(r);
-      DecompressAndZeroPad(dbf, B_span, packed_ofs, to, kc);
-      // Verify that we zero-padded.
-      if constexpr (HWY_IS_DEBUG_BUILD) {
-        for (size_t i = kc; i < hwy::RoundUpTo(kc, NBF); ++i) {
-          HWY_DASSERT(hwy::ConvertScalarTo<float>(to[i]) == 0.0f);
-        }
-      }
-    }
-    return B_view;
-  }
-
-  const MMArgs args_;  // copy for locality
-
-  const IndexRange range_n_;
-  // From MMConfig:
-  const size_t mr_;
-  const IndexRangePartition ranges_mc_;
-  const IndexRangePartition ranges_kc_;
-  const IndexRangePartition ranges_nc_;
-  const MMOrder order_;
-  const size_t inner_tasks_;
-};  // MMState
+};  // MMLoops
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
 //
@@ -1109,29 +1024,30 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
   static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
+  const size_t cluster_idx = options.cluster_idx;
+  HWY_DASSERT(cluster_idx < env.row_ptrs.size());
   PROFILER_ZONE3(env.ctx.profiler,
-                 options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
-                 zone);
+                 cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), zone);
 
-  HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
-  RowPtrs<TC> C_rows =
-      GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]);
+  RowPtrs<TC> C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[cluster_idx]);
 
   const size_t M = A.Rows();
   const size_t K = A.Cols();
   const size_t N = B.Rows();
 
   const CacheInfo& cache = env.ctx.cache_info;
-  MMPerKey& per_key = MMImpl::FindOrAddPerKey(
-      M, K, N, cache.VectorBytes(), env.per_cluster[options.cluster_idx]);
-  MMAutoTune<MMConfig>& tuner = per_key.autotune;
+  MMPerKey& per_key = MMImpl::FindOrAddPerKey(M, K, N, cache.VectorBytes(),
+                                              env.per_cluster[cluster_idx]);
 
-  const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
-                    add, options);
+  // (Also auto-tunes, hence outside the timed section to prevent interference.)
+  const StridedViewBF A_view =
+      MMImpl::MaybeDecompressA(A, per_key.autotune_par_a, env, options);
+
+  MMAutoTune<MMConfig>& tuner = per_key.autotune;
   if (HWY_LIKELY(tuner.Best())) {
-    const MMState state(M, K, N, args, *tuner.Best());
-    const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
-    state.DispatchParallelism(A_view, B, C_rows);
+    const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
+                      add, options, tuner, *tuner.Best());
+    MMLoops::Dispatch(A_view, B, C_rows, args);
     return &per_key;
   }
 
@@ -1147,14 +1063,13 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
         MMCandidates(cache, M, K, N, sizeof(TC), env.print_config));
   }
 
-  // (Also auto-tunes, hence outside the timed section to prevent interference.)
-  const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
-
   const MMConfig& cfg = tuner.NextConfig();
+  const MMArgs args(env, M, K, N, static_cast<double>(A.Scale()) * B.Scale(),
+                    add, options, tuner, cfg);
+
   const uint64_t t0 = hwy::timer::Start();
-  MMState state(M, K, N, args, cfg);
-  state.DispatchParallelism(A_view, B, C_rows);
-  MMImpl::NotifyAutotuneResult(M, K, N, t0, cfg, env, tuner);
+  MMLoops::Dispatch(A_view, B, C_rows, args);
+  MMImpl::NotifyAutotuneResult(env, M, K, N, t0, tuner, cfg);
 
   return &per_key;
 }

diff --git a/ops/matmul.h b/ops/matmul.h
index 946673a..915970c 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -21,7 +21,7 @@
 #include <stddef.h>
 #include <stdint.h>
 
-#include <memory>  // std::unique_ptr
+#include <functional>
 #include <vector>
 
 // IWYU pragma: begin_exports
@@ -54,13 +54,58 @@ HWY_INLINE_VAR constexpr size_t kNR = 4;
 // or less on ISAs with fewer registers, or for the last few rows of A.
 HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
 
+HWY_INLINE_VAR constexpr size_t kMaxNC = 16384;  // TODO: shrink?
+
 // Upper bound for per-worker B storage on the stack. Chosen such that one row
 // of BF16 A and B fit in 32 KiB L1, but there may be up to `kMaxMR` rows of A
 // and `kNR` rows of B.
 HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
 
+// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
+// Also used to decompress B, hence non-const.
+#pragma pack(push, 1)  // power of two size
+template <typename T>
+class StridedView {
+ public:
+  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
+      : row0_(row0),
+        cols_(static_cast<uint32_t>(cols)),
+        stride_(static_cast<uint32_t>(stride)) {
+    HWY_DASSERT(stride >= cols);
+  }
+
+  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
+  size_t Cols() const { return static_cast<size_t>(cols_); }
+
+  size_t Stride() const { return static_cast<size_t>(stride_); }
+  void SetStride(size_t stride) {
+    HWY_DASSERT(stride >= Cols());
+    stride_ = stride;
+  }
+
+  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
+  StridedView View(size_t r, size_t c, size_t cols) const {
+    HWY_DASSERT(c < Cols());
+    HWY_DASSERT(cols <= Cols() - c);
+    return StridedView(Row(r) + c, cols, stride_);
+  }
+
+ private:
+  T* HWY_RESTRICT row0_;
+  uint32_t cols_;
+  uint32_t stride_;
+};
+#pragma pack(pop)
+
+using StridedViewBF = StridedView<BF16>;
+using StridedViewD = StridedView<double>;
+
+using MMFused = std::function;
+
 struct MMOptions {
   uint32_t cluster_idx = 0;  // for `parallelism == kWithinCluster`.
   ParallelismStrategy parallelism = ParallelismStrategy::kHierarchical;
+
+  MMFused fused;
 };
 
 // Policy classes for parallelism, implementing some of `ParallelismStrategy`.
@@ -260,49 +305,26 @@ struct MMParallelHierarchical {
   }
 };
 
+template <class Func, typename... Args>
+void DispatchParallelism(ParallelismStrategy parallelism, const Func& func,
+                         Args&&... args) {
+  switch (parallelism) {
+    case ParallelismStrategy::kNone:
+      return func(MMParallelNone(), std::forward<Args>(args)...);
+    case ParallelismStrategy::kWithinCluster:
+      return func(MMParallelWithinCluster(), std::forward<Args>(args)...);
+    case ParallelismStrategy::kHierarchical:
+      return func(MMParallelHierarchical(), std::forward<Args>(args)...);
+    default:
+      HWY_UNREACHABLE;
+  }
+}
+
 void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC);
 // C is BF16/float.
 void BindC(ThreadingContext& ctx, MatPtr& C);
 
-// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
-// Also used to decompress B, hence non-const.
-#pragma pack(push, 1)  // power of two size
-template <typename T>
-class StridedView {
- public:
-  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
-      : row0_(row0),
-        cols_(static_cast<uint32_t>(cols)),
-        stride_(static_cast<uint32_t>(stride)) {
-    HWY_DASSERT(stride >= cols);
-  }
-
-  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
-  size_t Cols() const { return static_cast<size_t>(cols_); }
-
-  size_t Stride() const { return static_cast<size_t>(stride_); }
-  void SetStride(size_t stride) {
-    HWY_DASSERT(stride >= Cols());
-    stride_ = stride;
-  }
-
-  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
-  StridedView View(size_t r, size_t c, size_t cols) const {
-    HWY_DASSERT(c < Cols());
-    HWY_DASSERT(cols <= Cols() - c);
-    return StridedView(Row(r) + c, cols, stride_);
-  }
-
- private:
-  T* HWY_RESTRICT row0_;
-  uint32_t cols_;
-  uint32_t stride_;
-};
-#pragma pack(pop)
-
-using StridedViewBF = StridedView<BF16>;
-using StridedViewD = StridedView<double>;
-
+// Storage for decompressed A.
 class MMStorage {
  public:
   // Compile-time bounds on matrix columns to enable pre-allocating storage
   //
@@ -354,6 +376,28 @@ enum class MMOrder : uint8_t {
   // no kM* because we expect M (batch size) to be small relative to K and N.
 };
 
+// Tag types for `DispatchOrder`.
+struct MMOrderNT_K {};
+struct MMOrderNT {};
+struct MMOrderNT_MT_K {};
+struct MMOrderNT_MT {};
+
+template <class Func, typename... Args>
+void DispatchOrder(MMOrder order, const Func& func, Args&&... args) {
+  switch (order) {
+    case MMOrder::kNT_K:
+      return func(MMOrderNT_K(), std::forward<Args>(args)...);
+    case MMOrder::kNT:
+      return func(MMOrderNT(), std::forward<Args>(args)...);
+    case MMOrder::kNT_MT_K:
+      return func(MMOrderNT_MT_K(), std::forward<Args>(args)...);
+    case MMOrder::kNT_MT:
+      return func(MMOrderNT_MT(), std::forward<Args>(args)...);
+    default:
+      HWY_UNREACHABLE;
+  }
+}
+
 static inline bool IsBlock(MMOrder order) {
   return order == MMOrder::kNT_MT_K || order == MMOrder::kNT_MT;
 }
@@ -693,26 +737,46 @@ struct MatMulEnv {
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
 };
 
-// Arguments to MatMul() that are independent of the A/B/C types.
-// Reduces register pressure compared to individual values/references.
+// Arguments to MatMul() that are independent of the A/B/C types. Reduces
+// register pressure compared to individual values/references. Also used for
+// passing through `DispatchOrder`.
 struct MMArgs {
-  MMArgs(MatMulEnv& env, MMPerKey& per_key, double scale,
-         const float* HWY_RESTRICT add, MMOptions options)
-      : env(&env),
-        per_key(&per_key),
+  MMArgs(MatMulEnv& env, size_t M, size_t K, size_t N, double scale,
+         const float* HWY_RESTRICT add, MMOptions options,
+         const MMAutoTune<MMConfig>& autotune, const MMConfig& config)
+      : env(env),
+        line_bytes(env.ctx.cache_info.LineBytes()),
+        range_n(0, N),
         scale(scale),
         add(add),
         options(options),
-        line_bytes(env.ctx.cache_info.LineBytes()) {}
-  MatMulEnv* env;
-  MMPerKey* per_key;
+        autotune(autotune),
+        mr(config.MR()),
+        ranges_mc(config.RangesOfMC(M)),
+        ranges_kc(config.RangesOfKC(K)),
+        ranges_nc(config.RangesOfNC(N)),
+        order(config.Order()),
+        inner_tasks(config.InnerTasks()) {}
 
-  double scale;
+  MatMulEnv& env;
+  const size_t line_bytes;  // from `env`, for `Stride`.
+
+  // MatMul arguments:
+  const IndexRange range_n;  // entire N
+  const double scale;
   const float* HWY_RESTRICT add;
+  const MMOptions options;
 
-  MMOptions options;
-  size_t line_bytes;
+  const MMAutoTune<MMConfig>& autotune;  // for `MaybeEnter`
+
+  // From `MMConfig`:
+  const size_t mr;
+  const IndexRangePartition ranges_mc;
+  const IndexRangePartition ranges_kc;
+  const IndexRangePartition ranges_nc;
+  const MMOrder order;
+  const size_t inner_tasks;
 };
 
 // Wrapper over hwy::Zone that is only enabled when autotuning finished.
@@ -729,11 +793,12 @@ class MMZone {
     }
   }
 
-  // `name` must be a string literal.
+  template <class AutoTune>
   void MaybeEnter(size_t thread, hwy::profiler::ZoneHandle zone,
-                  const MMArgs& args) {
-    if (args.per_key->WantProfile()) {
-      new (&data_) Zone(args.env->ctx.profiler, thread, zone);
+                  const MatMulEnv& env, const AutoTune* auto_tune) {
+    // Only if enabled and autotuning finished.
+    if (PROFILER_ENABLED && auto_tune->Best()) {
+      new (&data_) Zone(env.ctx.profiler, thread, zone);
       HWY_DASSERT(data_ != 0);
     }
   }
@@ -744,7 +809,8 @@ class MMZone {
 };
 #else
 struct MMZone {
-  void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MMArgs&) {}
+  void MaybeEnter(size_t, hwy::profiler::ZoneHandle, const MatMulEnv&,
+                  const void*) {}
 };
 #endif  // PROFILER_ENABLED

diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h
index 8be84ec..c8feda9 100644
--- a/ops/matvec-inl.h
+++ b/ops/matvec-inl.h
@@ -37,7 +37,6 @@
 #include "compression/compress-inl.h"
 #include "ops/dot-inl.h"
-#include "ops/matmul.h"
 #include "util/mat.h"  // MatPtrT
 #include "hwy/contrib/math/math-inl.h"
 #include "hwy/contrib/matvec/matvec-inl.h"

diff --git a/util/mat.h b/util/mat.h
index c084e81..c8a4617 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -40,9 +40,10 @@ class RowPtrs {
  public:
   RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs) {}
 
-  T* HWY_RESTRICT operator[](size_t row_idx) const {
+  T* HWY_RESTRICT Row(size_t row_idx) const {
     return HWY_RCAST_ALIGNED(T*, row_ptrs_[row_idx]);
   }
+  T* HWY_RESTRICT operator[](size_t row_idx) const { return Row(row_idx); }
 
  private:
   uint8_t** row_ptrs_;
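
Aside, for reviewers: a minimal, self-contained sketch of the tag-dispatch pattern this patch adopts. `DispatchParallelism`/`DispatchOrder` map a runtime enum onto empty tag types so that the `Loop` bodies are selected by overload resolution instead of a switch inside the loop class. The names below are simplified stand-ins, not the library's actual types.

#include <cstdio>
#include <utility>

enum class Order { kNT, kNT_K };
struct OrderNT {};
struct OrderNT_K {};

// Maps the runtime enum to a tag type and forwards any extra arguments,
// mirroring the shape of DispatchOrder in ops/matmul.h above.
template <class Func, typename... Args>
void DispatchOrder(Order order, const Func& func, Args&&... args) {
  switch (order) {
    case Order::kNT:
      return func(OrderNT(), std::forward<Args>(args)...);
    case Order::kNT_K:
      return func(OrderNT_K(), std::forward<Args>(args)...);
  }
}

// Overloads replace the former member-function switch: each tag selects a
// different loop structure, resolved at compile time.
void Loop(OrderNT, int n) { std::printf("NT loop over %d columns\n", n); }
void Loop(OrderNT_K, int n) { std::printf("NT_K loop over %d columns\n", n); }

int main() {
  const Order order = Order::kNT_K;  // e.g. chosen by an autotuner
  DispatchOrder(order, [](auto tag) { Loop(tag, 128); });
  return 0;
}

Compared with the switch formerly inside MMState, this keeps the enum-to-implementation branch out of the hot path and lets each order/parallelism combination be instantiated and inlined separately, which is also what makes the stateless MMLoops/MMArgs split possible.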