Fix MatMul issue caused by autotuning bucketing, refs #608, thanks @ufownl

PiperOrigin-RevId: 771077158
This commit is contained in:
Jan Wassenberg 2025-06-13 06:58:05 -07:00 committed by Copybara-Service
parent 01cdefeda7
commit 2c72ff2aa5
4 changed files with 23 additions and 19 deletions

View File

@ -285,6 +285,7 @@ cc_library(
deps = [ deps = [
":allocator", ":allocator",
":basics", ":basics",
":mat",
":matmul", ":matmul",
":threading_context", ":threading_context",
"//compression:compress", "//compression:compress",

View File

@ -23,6 +23,7 @@
#include "ops/matmul.h" // IWYU pragma: export #include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" #include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
@ -347,13 +348,6 @@ class MMAddHorizontalSumsIntoPartial {
// Stateless, wraps member functions. // Stateless, wraps member functions.
class MMKernel { class MMKernel {
public: public:
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view` // Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0. // is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output. // A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
@ -917,7 +911,7 @@ class MMPerPackage {
return HWY_MAX(kNR, line_bytes_ / sizeof_TC); return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
} }
// Single M and K, parallel N. Fills all of C directly. // Single M and K ranges, parallel N. Fills all of C directly.
template <typename TB, typename TC> template <typename TB, typename TC>
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone; MMZone zone;
@ -950,7 +944,7 @@ class MMPerPackage {
HWY_DASSERT(out_ == MMOut::kDirect); // already filled C HWY_DASSERT(out_ == MMOut::kDirect); // already filled C
} }
// Single M, parallel N, sequential K. Fills all of partial. // Single M range, parallel N, sequential K. Fills all of partial.
template <typename TB, typename TC> template <typename TB, typename TC>
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const { HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
MMZone zone; MMZone zone;
@ -1365,9 +1359,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
HWY_ASSERT(N % kNR == 0); HWY_ASSERT(N % kNR == 0);
// Negligible CPU time. // Negligible CPU time.
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
MMKernel::kMaxMR, kNR, per_key.ranges_np, kNR, per_key.ranges_np, env.print_config));
env.print_config));
} }
const MMConfig& cfg = tuner.NextConfig(); const MMConfig& cfg = tuner.NextConfig();

View File

@ -153,7 +153,7 @@ class GenerateCandidates {
// 2D blocking is useless for a single row of M. // 2D blocking is useless for a single row of M.
if (IsBlock(order) && M_ <= mr) continue; if (IsBlock(order) && M_ <= mr) continue;
// Conversely, N-only parallelism is uncompetitive for large M. // Conversely, N-only parallelism is uncompetitive for large M.
if (!IsBlock(order) && M_ >= 8 * mr) continue; if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue;
orders.push_back(order); orders.push_back(order);
} }
} }

View File

@ -46,6 +46,13 @@ namespace gcpp {
// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`. // `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4; constexpr size_t kNR = 4;
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
// Mostly stateless, can be constructed on the fly by weights.cc, but captures // Mostly stateless, can be constructed on the fly by weights.cc, but captures
// the singleton ThreadingContext to reduce MatMul call overhead. // the singleton ThreadingContext to reduce MatMul call overhead.
class MMParallel { class MMParallel {
@ -558,16 +565,19 @@ class MMAutoTune {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
// but choosing the same config for a larger M can result in multiple MC ranges.
// Thus M less than this must have unique keys/configs.
static constexpr size_t kMaxTilesM = 8;
// Map of previously seen dimensions to index via linear search. // Map of previously seen dimensions to index via linear search.
class MMKeys { class MMKeys {
// Group batch size into buckets to reduce #auto-tunes. // Group batch size into buckets to reduce #auto-tunes.
static size_t BucketM(size_t M) { static size_t BucketM(size_t M) {
// The first 4 may require their own bucket because `kNT` only works for a if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above.
// single M range, but that depends on the config's `MR()`. if (M <= 128) return 128;
if (M <= 4) return M; return 512;
if (M <= 16) return 16;
if (M <= 64) return 64;
return 256;
} }
public: public: