mirror of https://github.com/google/gemma.cpp.git
Fix MatMul issue caused by autotuning bucketing, refs #608, thanks @ufownl
PiperOrigin-RevId: 771077158
This commit is contained in:
parent
01cdefeda7
commit
2c72ff2aa5
|
|
@ -285,6 +285,7 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":matmul",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -347,13 +348,6 @@ class MMAddHorizontalSumsIntoPartial {
|
|||
// Stateless, wraps member functions.
|
||||
class MMKernel {
|
||||
public:
|
||||
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
|
||||
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
|
||||
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
|
||||
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
|
||||
// or less on ISAs with fewer registers, or for the last few rows of A.
|
||||
static constexpr size_t kMaxMR = 4;
|
||||
|
||||
// Calls `LoopKC` for each of `mc` rows of A in steps of `mr`. `A_view`
|
||||
// is `mc x kc` and `B_view` is `(kNR x kc)`. Both start at row/col 0.
|
||||
// A2C0 in MOMMS terminology updates a `mc x kNR` slice of the output.
|
||||
|
|
@ -917,7 +911,7 @@ class MMPerPackage {
|
|||
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
|
||||
}
|
||||
|
||||
// Single M and K, parallel N. Fills all of C directly.
|
||||
// Single M and K ranges, parallel N. Fills all of C directly.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
MMZone zone;
|
||||
|
|
@ -950,7 +944,7 @@ class MMPerPackage {
|
|||
HWY_DASSERT(out_ == MMOut::kDirect); // already filled C
|
||||
}
|
||||
|
||||
// Single M, parallel N, sequential K. Fills all of partial.
|
||||
// Single M range, parallel N, sequential K. Fills all of partial.
|
||||
template <typename TB, typename TC>
|
||||
HWY_INLINE void DoNT_K(const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
|
||||
MMZone zone;
|
||||
|
|
@ -1365,9 +1359,8 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
HWY_ASSERT(N % kNR == 0);
|
||||
|
||||
// Negligible CPU time.
|
||||
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC),
|
||||
MMKernel::kMaxMR, kNR, per_key.ranges_np,
|
||||
env.print_config));
|
||||
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
|
||||
kNR, per_key.ranges_np, env.print_config));
|
||||
}
|
||||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class GenerateCandidates {
|
|||
// 2D blocking is useless for a single row of M.
|
||||
if (IsBlock(order) && M_ <= mr) continue;
|
||||
// Conversely, N-only parallelism is uncompetitive for large M.
|
||||
if (!IsBlock(order) && M_ >= 8 * mr) continue;
|
||||
if (!IsBlock(order) && M_ >= kMaxTilesM * mr) continue;
|
||||
orders.push_back(order);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
22
ops/matmul.h
22
ops/matmul.h
|
|
@ -46,6 +46,13 @@ namespace gcpp {
|
|||
// `MMAddHorizontalSumsIntoPartial`. We ensure `C.Cols() % kNR == 0`.
|
||||
constexpr size_t kNR = 4;
|
||||
|
||||
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
|
||||
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
|
||||
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
|
||||
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
|
||||
// or less on ISAs with fewer registers, or for the last few rows of A.
|
||||
static constexpr size_t kMaxMR = 4;
|
||||
|
||||
// Mostly stateless, can be constructed on the fly by weights.cc, but captures
|
||||
// the singleton ThreadingContext to reduce MatMul call overhead.
|
||||
class MMParallel {
|
||||
|
|
@ -558,16 +565,19 @@ class MMAutoTune {
|
|||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Minimum M, in units of tile rows of height mr={1, 2, 4}, from which
|
||||
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
|
||||
// but choosing the same config for a larger M can result in multiple MC ranges.
|
||||
// Thus any M less than `kMaxTilesM * kMaxMR` must have its own key/config.
|
||||
static constexpr size_t kMaxTilesM = 8;
|
||||
|
||||
// Map of previously seen dimensions to index via linear search.
|
||||
class MMKeys {
|
||||
// Group batch size into buckets to reduce #auto-tunes.
|
||||
static size_t BucketM(size_t M) {
|
||||
// The first 4 may require their own bucket because `kNT` only works for a
|
||||
// single M range, but that depends on the config's `MR()`.
|
||||
if (M <= 4) return M;
|
||||
if (M <= 16) return 16;
|
||||
if (M <= 64) return 64;
|
||||
return 256;
|
||||
if (M < kMaxTilesM * kMaxMR) return M; // See kMaxTilesM above.
|
||||
if (M <= 128) return 128;
|
||||
return 512;
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
|
|||
Loading…
Reference in New Issue