diff --git a/BUILD.bazel b/BUILD.bazel index 0e4df32..3b894be 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -285,6 +285,7 @@ cc_library( deps = [ ":allocator", ":basics", + ":mat", ":matmul", ":threading_context", "//compression:compress", diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 806deda..4d3efd0 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -23,6 +23,7 @@ #include "ops/matmul.h" // IWYU pragma: export #include "util/allocator.h" #include "util/basics.h" +#include "util/mat.h" #include "util/threading_context.h" #include "hwy/base.h" #include "hwy/profiler.h" @@ -347,13 +348,6 @@ class MMAddHorizontalSumsIntoPartial { // Stateless, wraps member functions. class MMKernel { public: - // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because - // we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile. - // In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions - // that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`, - // or less on ISAs with fewer registers, or for the last few rows of A. - static constexpr size_t kMaxMR = 4; - // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. @@ -917,7 +911,7 @@ class MMPerPackage { return HWY_MAX(kNR, line_bytes_ / sizeof_TC); } - // Single M and K, parallel N. Fills all of C directly. + // Single M and K ranges, parallel N. Fills all of C directly. template HWY_INLINE void DoNT(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; @@ -950,7 +944,7 @@ class MMPerPackage { HWY_DASSERT(out_ == MMOut::kDirect); // already filled C } - // Single M, parallel N, sequential K. Fills all of partial. + // Single M range, parallel N, sequential K. Fills all of partial. 
template HWY_INLINE void DoNT_K(const MatPtrT& B, RowPtrs C_rows) const { MMZone zone; @@ -1365,9 +1359,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, HWY_ASSERT(N % kNR == 0); // Negligible CPU time. - tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), - MMKernel::kMaxMR, kNR, per_key.ranges_np, - env.print_config)); + tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, + kNR, per_key.ranges_np, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index 7e83830..3da5512 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -153,7 +153,7 @@ class GenerateCandidates { // 2D blocking is useless for a single row of M. if (IsBlock(order) && M_ <= mr) continue; // Conversely, N-only parallelism is uncompetitive for large M. - if (!IsBlock(order) && M_ >= 8 * mr) continue; + if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue; orders.push_back(order); } } diff --git a/ops/matmul.h b/ops/matmul.h index bb35fa2..3297fcc 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -46,6 +46,13 @@ namespace gcpp { // `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. constexpr size_t kNR = 4; +// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because +// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile. +// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions +// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`, +// or less on ISAs with fewer registers, or for the last few rows of A. +static constexpr size_t kMaxMR = 4; + // Mostly stateless, can be constructed on the fly by weights.cc, but captures // the singleton ThreadingContext to reduce MatMul call overhead. 
class MMParallel { @@ -558,16 +565,19 @@ class MMAutoTune { //------------------------------------------------------------------------------ +// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which +// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range, +// but choosing the same config for a larger M can result in multiple MC ranges. +// Thus any M below `kMaxTilesM * kMaxMR` must have unique keys/configs. +static constexpr size_t kMaxTilesM = 8; + // Map of previously seen dimensions to index via linear search. class MMKeys { // Group batch size into buckets to reduce #auto-tunes. static size_t BucketM(size_t M) { - // The first 4 may require their own bucket because `kNT` only works for a - // single M range, but that depends on the config's `MR()`. - if (M <= 4) return M; - if (M <= 16) return 16; - if (M <= 64) return 64; - return 256; + if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above. + if (M <= 128) return 128; + return 512; } public: