From 4be479972758b9064db24d1b629c23c5fc227f97 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Sep 2025 03:32:35 -0700 Subject: [PATCH] Remove kMaxPackages and per-package-related code matmul: remove kMaxClusters, dynamic allocation PiperOrigin-RevId: 802950348 --- gemma/activations.h | 8 +-- ops/dot_test.cc | 1 - ops/matmul-inl.h | 116 +++++++++++++++++++--------------- ops/matmul.cc | 121 ++++++++++------------------------- ops/matmul.h | 133 ++++++++++++++++----------------------- ops/matmul_test.cc | 41 +++++------- util/allocator.cc | 5 +- util/allocator.h | 2 +- util/basics.h | 7 --- util/mat.h | 1 + util/threading_context.h | 12 ++-- 11 files changed, 177 insertions(+), 270 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index cd714ae..63b3153 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -21,7 +21,6 @@ #include #include -#include #include #include "gemma/configs.h" // ModelConfig @@ -62,11 +61,12 @@ struct AttentionActivations { // Returns the scale value to use for the query in the attention computation. // Also called by ops_test. static inline float ChooseQueryScale(const ModelConfig& config) { + const LayerConfig& layer_config = config.layer_configs[0]; if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / sqrtf(static_cast(config.model_dim / - config.layer_configs[0].heads)); + return 1.0f / + sqrtf(static_cast(config.model_dim / layer_config.heads)); // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); + return 1.0f / sqrtf(static_cast(layer_config.qkv_dim)); } AttentionActivations( diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 2c0ae3a..8afb220 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -1101,7 +1101,6 @@ void TestAllDot() { // Limit workers because we only support `kMaxWorkers`. ThreadingArgs threading_args; - threading_args.max_packages = 1; threading_args.max_clusters = 1; threading_args.max_lps = kMaxWorkers - 1; ThreadingContext ctx(threading_args); diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index c915b14..8b9c011 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -570,10 +570,29 @@ struct MMImpl { // Returns existing entry for the given key or -1. static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) { const hwy::Span all_keys = keys.Keys(); - // TODO: SIMD scan - for (size_t i = 0; i < all_keys.size(); ++i) { - if (all_keys[i] == key) return static_cast(i); + + const hn::ScalableTag d; + using V = hn::Vec; + const V broadcasted = Set(d, key); + const size_t N = hn::Lanes(d); + + size_t i = 0; + if (all_keys.size() >= N) { + for (; i <= all_keys.size() - N; i += N) { + const intptr_t pos = hn::FindFirstTrue( + d, hn::Eq(broadcasted, hn::LoadU(d, &all_keys[i]))); + if (pos >= 0) return static_cast(i) + pos; + } } + + const size_t remaining = all_keys.size() - i; + if (HWY_LIKELY(remaining > 0)) { + HWY_DASSERT(remaining < N); + const V v = hn::LoadN(d, &all_keys[i], remaining); + const intptr_t pos = hn::FindFirstTrue(d, hn::Eq(broadcasted, v)); + if (pos >= 0) return static_cast(i) + pos; + } + return -1; } @@ -582,6 +601,15 @@ struct MMImpl { args.env->ctx.pools.MaxWorkersPerCluster(); } + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. 
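
The vectorized key scan introduced above follows the usual Highway pattern: full-vector iterations first, then a `LoadN`-guarded remainder. A statically dispatched, standalone sketch of the same pattern (the name `FindKeySketch` and the plain `uint64_t` array are illustrative; the real code scans `MMKeys::Keys()` and is compiled per target):

    #include <cstddef>
    #include <cstdint>

    #include "hwy/highway.h"

    namespace hn = hwy::HWY_NAMESPACE;

    // Linear search for `key` in `keys[0, num)`, returning its index or -1.
    // Assumes `key != 0` (true for packed nonzero dimensions), so the zero
    // lanes that LoadN pads with can never compare equal.
    intptr_t FindKeySketch(const uint64_t* keys, size_t num, uint64_t key) {
      const hn::ScalableTag<uint64_t> d;
      const size_t N = hn::Lanes(d);
      const auto broadcasted = hn::Set(d, key);

      size_t i = 0;
      if (num >= N) {
        for (; i <= num - N; i += N) {
          const intptr_t pos =
              hn::FindFirstTrue(d, hn::Eq(broadcasted, hn::LoadU(d, keys + i)));
          if (pos >= 0) return static_cast<intptr_t>(i) + pos;
        }
      }
      const size_t remaining = num - i;
      if (remaining != 0) {
        const auto v = hn::LoadN(d, keys + i, remaining);
        const intptr_t pos = hn::FindFirstTrue(d, hn::Eq(broadcasted, v));
        if (pos >= 0) return static_cast<intptr_t>(i) + pos;
      }
      return -1;
    }
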
+ template + static StridedView View(const MatPtrT& AB, size_t r, size_t c, + size_t cols) { + HWY_DASSERT(c < AB.Cols()); + HWY_DASSERT(cols <= AB.Cols() - c); + return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); + } + template static void DispatchParallelism(ParallelismStrategy parallelism, const Func& func) { @@ -651,11 +679,11 @@ struct MMImpl { DispatchParallelism( args.options.parallelism, [&](const auto& parallel) { - parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks, - args.options.cluster_idx, - [&](const IndexRange& range_K, size_t worker) { - do_range(all_M, range_K, worker); - }); + parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks, + args.options.cluster_idx, + [&](const IndexRange& range_K, size_t worker) { + do_range(all_M, range_K, worker); + }); }); break; } @@ -676,7 +704,7 @@ struct MMImpl { static HWY_INLINE void DecompressA(const MatPtrT& A, const StridedViewBF A_view, const MMArgs& args) { - MMAutoTune& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0]; + MMAutoTune& autotune = args.per_key->autotune_par_a; if (HWY_LIKELY(autotune.Best())) { return DoDecompressA(A, A_view, *autotune.Best(), args); @@ -703,15 +731,6 @@ struct MMImpl { } } - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - template - static StridedView View(const MatPtrT& AB, size_t r, size_t c, - size_t cols) { - HWY_DASSERT(c < AB.Cols()); - HWY_DASSERT(cols <= AB.Cols() - c); - return StridedView(const_cast(AB.Row(r)) + c, cols, AB.Stride()); - } - template static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT& A, const MMArgs& args) { @@ -723,8 +742,7 @@ struct MMImpl { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. const StridedViewBF A_view = - args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0, - A.Extents()); + args.env->storage[args.options.cluster_idx].A(A.Extents()); DecompressA(A, A_view, args); return A_view; } @@ -735,17 +753,16 @@ struct MMImpl { // loops over the inner KC and MC. Member variables avoid long argument lists. class MMState { public: - MMState(const Extents2D A, const MMArgs& args, const MMConfig& config) + MMState(const Extents2D A, const size_t B_rows, const MMArgs& args, + const MMConfig& config) : args_(args), - range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)), + range_n_(0, B_rows), mr_(config.MR()), ranges_mc_(config.RangesOfMC(A.rows)), ranges_kc_(config.RangesOfKC(A.cols)), - ranges_nc_(config.RangesOfNC(range_np_)), + ranges_nc_(config.RangesOfNC(B_rows)), order_(config.Order()), - inner_tasks_(config.InnerTasks()) { - HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1); - } + inner_tasks_(config.InnerTasks()) {} // Called from `MatMul` from two places: either with the next autotune config, // or with the best config. @@ -768,12 +785,12 @@ class MMState { // Compute size of per-worker storage for `kNR` row ranges of B. Stack // allocation avoids passing a worker index. static constexpr size_t B_stride_max_ = - MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); + kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16); static constexpr size_t B_storage_max_ = kNR * B_stride_max_; - // Granularity of `ForNP`. B rows produce C columns, so we + // Granularity of `ForN`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. 
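
`View` above exposes a window of a padded matrix without copying: the window keeps the parent's row stride, so only the base pointer and width change. A minimal standalone illustration (`ViewSketch`/`MakeView` are hypothetical names, not the repo's `StridedView`):

    #include <cstddef>

    // A 2D window into row-major storage: base pointer, width, parent stride.
    template <typename T>
    struct ViewSketch {
      T* base;        // &parent(r, c)
      size_t cols;    // window width
      size_t stride;  // parent row stride in elements; unchanged by the window
      T* Row(size_t i) const { return base + i * stride; }
    };

    // Window whose top-left is (r, c) and whose width is `cols`.
    template <typename T>
    ViewSketch<T> MakeView(T* parent, size_t parent_stride, size_t r, size_t c,
                           size_t cols) {
      return ViewSketch<T>{parent + r * parent_stride + c, cols, parent_stride};
    }
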
- size_t MultipleNP(size_t sizeof_TC) const { + size_t MultipleN(size_t sizeof_TC) const { return HWY_MAX(kNR, args_.line_bytes / sizeof_TC); } @@ -812,8 +829,8 @@ class MMState { Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes); // Similar to `loop_nc` below, but here we hoisted `A_view`. - parallel.ForNP( - args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + parallel.ForN( + args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -861,8 +878,8 @@ class MMState { } }; - parallel.ForNP( - args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_, + parallel.ForN( + args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_, args_.options.cluster_idx, [&](const IndexRange& range_nc, size_t worker) HWY_ATTR { MMZone mm_zone; @@ -881,7 +898,7 @@ class MMState { }); } - // Parallel loops over mc/nc blocks of M/range_np, single K. + // Parallel loops over mc/nc blocks of M/range_n, single K. // Fills `mc x nc` sections of C directly, in parallel. template HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A, @@ -923,7 +940,7 @@ class MMState { const MatPtrT& B, RowPtrs C_rows) const { static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K"); const size_t kc_max = ranges_kc_.TaskSize(); - HWY_DASSERT(kc_max <= MMStorage::kMaxKC); + HWY_DASSERT(kc_max <= kMaxKC); const size_t B_stride = Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes); // Sequential loop over NC/MC/KC, for when the M/N loops are @@ -1002,7 +1019,7 @@ class MMState { const MMArgs args_; // copy for locality - const IndexRange range_np_; + const IndexRange range_n_; // From MMConfig: const size_t mr_; const IndexRangePartition ranges_mc_; @@ -1036,38 +1053,33 @@ template HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, const float* HWY_RESTRICT add, MatMulEnv& env, MatPtrT& C, MMOptions options = MMOptions()) { + const Allocator& allocator = env.ctx.allocator; HWY_DASSERT(options.cluster_idx < env.row_ptrs.size()); + MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx]; RowPtrs C_rows = GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]); - const Allocator& allocator = env.ctx.allocator; const size_t M = A.Rows(); const size_t K = A.Cols(); const size_t N = B.Rows(); const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N); - intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]); + intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { - env.keys[options.cluster_idx].Append(key, allocator); - - size_t max_packages = kMaxPackages; - // For low-batch, multiple sockets only help if binding is enabled. 
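
The granularity rule in `MultipleN` exists because neighboring workers write neighboring column ranges of C; rounding each range up to at least one cache line of `TC` keeps their stores on distinct lines. A standalone sketch of the same computation (the 64-byte line in the example is an assumption, not queried from the allocator):

    #include <algorithm>
    #include <cstddef>

    constexpr size_t kNR = 4;  // as in ops/matmul.h

    // Minimum number of C columns per worker task: at least one kNR-wide tile
    // and at least one full cache line of TC elements.
    size_t MultipleNSketch(size_t line_bytes, size_t sizeof_tc) {
      return std::max(kNR, line_bytes / sizeof_tc);
    }
    // e.g. MultipleNSketch(64, sizeof(float)) == 16 columns per granule.
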
- if (!allocator.ShouldBind() && M <= 4) { - max_packages = 1; - } + per_cluster.keys.Append(key, allocator); // invalidates `MMAutoTune::Best()` - std::vector& stored_keys = env.per_key[options.cluster_idx]; - index = stored_keys.size(); - stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR)); + std::vector& per_keys = per_cluster.per_key; + index = per_keys.size(); + per_keys.push_back(MMPerKey()); } - MMPerKey& per_key = env.per_key[options.cluster_idx][index]; + MMPerKey& per_key = per_cluster.per_key[index]; MMAutoTune& tuner = per_key.autotune; const MMArgs args(env, per_key, static_cast(A.Scale()) * B.Scale(), add, options); if (HWY_LIKELY(tuner.Best())) { - const MMState state(A.Extents(), args, *tuner.Best()); + const MMState state(A.Extents(), B.Rows(), args, *tuner.Best()); const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); state.DispatchParallelism(A_view, B, C_rows); return &per_key; @@ -1092,12 +1104,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, } tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR, - kNR, per_key.ranges_np, env.print_config)); + kNR, env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); const uint64_t t0 = hwy::timer::Start(); - MMState state(A.Extents(), args, cfg); + MMState state(A.Extents(), B.Rows(), args, cfg); const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args); state.DispatchParallelism(A_view, B, C_rows); const uint64_t t1 = diff --git a/ops/matmul.cc b/ops/matmul.cc index 75b37a2..35887a5 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -21,7 +21,6 @@ #include #include -#include #include #include "util/allocator.h" @@ -65,7 +64,7 @@ class GenerateCandidates { public: GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) + bool print_config) : allocator_(allocator), M_(M), K_(K), @@ -79,7 +78,6 @@ class GenerateCandidates { // up to the line size. Both A and B are BF16. kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), nc_multiple_(allocator.StepBytes() / sizeof_TC), - ranges_np_(ranges_np), print_config_(print_config) {} std::vector operator()() const { @@ -177,8 +175,7 @@ class GenerateCandidates { allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream)); const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); - kc_max = - RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_); + kc_max = RoundDownWithFloor(HWY_MIN(kc_max, kMaxKC), kc_multiple_); kc_max = HWY_MIN(kc_max, K_); SizeVec all_kc(1, kc_max); @@ -258,32 +255,30 @@ class GenerateCandidates { // The number of (possibly L3 resident) B rows per `NT_MT` task. SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const { - const size_t np_max = ranges_np_.TaskSize(); - size_t nc_max = np_max; + size_t nc_max = N_; // Only if there will be reuse of B: choose the largest `nc_max` (C cols) // such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3. // Otherwise, leave it unbounded. 
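
The `MatMul` flow above (`SetCandidates`, then one `NextConfig` measurement per call until `Best()` is non-null) can be summarized by the following simplified, single-measurement model. `AutoTuneSketch` is illustrative only; the real `MMAutoTune` interface is `SetCandidates`/`NextConfig`/`Best`, and its measurement and skip logic is more involved than shown here.

    #include <chrono>
    #include <cstdint>
    #include <limits>
    #include <utility>
    #include <vector>

    template <typename Config>
    class AutoTuneSketch {
     public:
      void SetCandidates(std::vector<Config> c) { candidates_ = std::move(c); }
      const Config* Best() const { return best_; }

      // Runs the next untried candidate, keeps the fastest; after the last
      // candidate, Best() becomes non-null and callers stop calling RunNext.
      template <typename RunFunc>
      void RunNext(const RunFunc& run) {
        if (best_ != nullptr || candidates_.empty()) return;
        const Config& cfg = candidates_[next_];
        const auto t0 = std::chrono::steady_clock::now();
        run(cfg);
        const auto t1 = std::chrono::steady_clock::now();
        const int64_t ns =
            std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0)
                .count();
        if (ns < best_ns_) { best_ns_ = ns; best_idx_ = next_; }
        if (++next_ == candidates_.size()) best_ = &candidates_[best_idx_];
      }

     private:
      std::vector<Config> candidates_;
      size_t next_ = 0;
      size_t best_idx_ = 0;
      int64_t best_ns_ = std::numeric_limits<int64_t>::max();
      const Config* best_ = nullptr;
    };
    // Usage: tuner.SetCandidates(configs); until tuner.Best() is set, call
    // tuner.RunNext([&](const Config& c) { /* run MatMul with c */ });
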
if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_); - nc_max = - HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max); + nc_max = HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), N_); } HWY_DASSERT(nc_max != 0); nc_max = RoundDownWithFloor(nc_max, nc_multiple_); // If there are going to be multiple ranges, anything more than half would // be imbalanced and suboptimal. - if (nc_max < np_max && nc_max >= np_max / 2) { - nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_); + if (nc_max < N_ && nc_max >= N_ / 2) { + nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_); } // Non-block calls ForNP, which ignores `range_nc` and uses `range_np`. - if (!IsBlock(order)) return SizeVec(1, np_max); + if (!IsBlock(order)) return SizeVec(1, N_); SizeVec all_nc(1, nc_max); // Avoid proposing nc > N. - if (np_max > nc_multiple_) { + if (N_ > nc_multiple_) { // Large L3, but its behavior and characteristics varies across platforms, // hence autotune a wider range of nc than the other dimensions. size_t reps = 10; @@ -292,8 +287,7 @@ class GenerateCandidates { size_t prev = nc_max; for (size_t rep = 0; rep < reps; ++rep) { - const size_t div = - PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_); + const size_t div = PrevDivisor(nc_multiple_, prev, N_, nc_multiple_); prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_); all_nc.push_back(prev); if (prev == nc_multiple_) break; @@ -346,8 +340,6 @@ class GenerateCandidates { const size_t kc_multiple_; const size_t nc_multiple_; - IndexRangePartition ranges_np_; - const bool print_config_; }; @@ -357,58 +349,19 @@ class GenerateCandidates { std::vector MMCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, bool print_config) { return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, - ranges_np, print_config)(); -} - -// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote -// memory accesses or false sharing, unless there are insufficient per-package -// rows for that. -static size_t NPMultiple(const Allocator& allocator, size_t N, - size_t sizeof_TC, size_t nr, size_t num_packages) { - size_t np_multiple = allocator.BasePageBytes() / sizeof_TC; - // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For - // `N` < 4096, this can cause significant load imbalance. If split unevenly, - // choose a smaller multiple. - if (N % (np_multiple * num_packages)) { - const size_t min_multiple = allocator.LineBytes() / sizeof_TC; - np_multiple = - PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); - if (HWY_UNLIKELY(np_multiple == 0)) { - np_multiple = min_multiple; - } - // This happens in tests with small N, hence do not assert. 
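
The L3 bound used in `NC()` above amounts to a per-column budget: each C column keeps roughly `kc` BF16 elements of B plus `mc` elements of C (or the partial buffer) resident. A standalone sketch of that arithmetic (floor division here, where the real code rounds with `DivCeil`; the 32 MiB L3 is only an example value):

    #include <algorithm>
    #include <cstddef>

    size_t NcMaxSketch(size_t l3_bytes, size_t kc, size_t mc, size_t sizeof_tc,
                       size_t n) {
      const size_t bytes_per_nc = kc * /*sizeof(BF16)*/ 2 + mc * sizeof_tc;
      return std::min(l3_bytes / bytes_per_nc, n);
    }
    // e.g. NcMaxSketch(32u << 20, /*kc=*/2048, /*mc=*/64, sizeof(float), 16384)
    // allows about 7710 C columns per NC range before rounding to nc_multiple.
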
- if (N % (np_multiple * num_packages) && N >= 128) { - static std::atomic_flag warned = ATOMIC_FLAG_INIT; - if (!warned.test_and_set()) { - HWY_WARN( - "NPMultiple: N=%zu still not divisible by np_multiple=%zu * " - "num_packages=%zu\n", - N, np_multiple, num_packages); - } - np_multiple = nr; - } - } - return np_multiple; -} - -IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, - size_t N, size_t sizeof_TC, size_t nr) { - const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages()); - return StaticPartition( - IndexRange(0, N), num_packages, - NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages)); + print_config)(); } MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { // Create storage per cluster. This only applies to in-cluster parallelism. // For nested and sequential parallelism, a single MMStorage is used. const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers(); + per_cluster.resize(num_clusters); storage.reserve(num_clusters); for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { - storage.push_back(MMStorage(ctx)); + storage.push_back(MMStorage(ctx.allocator)); row_ptrs.push_back(hwy::AllocateAligned(kMaxBatchSize)); // C } @@ -423,20 +376,15 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) { PROFILER_ZONE("Startup.BindB"); - const IndexRangePartition ranges_np = - MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR); - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& rows_b = ranges_np.Range(pkg_idx); - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - uintptr_t begin = reinterpret_cast(B.RowBytes(rows_b.begin())); - uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes(); - // B row padding is less than the page size, so only bind the subset that - // is page-aligned. - begin = hwy::RoundUpTo(begin, allocator.BasePageBytes()); - end = hwy::RoundDownTo(end, allocator.BasePageBytes()); - if (HWY_LIKELY(begin != end)) { - allocator.BindMemory(reinterpret_cast(begin), end - begin, node); - } + const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + uintptr_t begin = reinterpret_cast(B.RowBytes(0)); + uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes(); + // B row padding is less than the page size, so only bind the subset that + // is page-aligned. + begin = hwy::RoundUpTo(begin, allocator.BasePageBytes()); + end = hwy::RoundDownTo(end, allocator.BasePageBytes()); + if (HWY_LIKELY(begin != end)) { + allocator.BindMemory(reinterpret_cast(begin), end - begin, node); } } @@ -447,25 +395,20 @@ void BindC(ThreadingContext& ctx, MatPtr& C) { PROFILER_ZONE("Startup.BindC"); - const IndexRangePartition ranges_np = - MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR); - bool ok = true; - for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { - const IndexRange& cols_c = ranges_np.Range(pkg_idx); - // `BindMemory` requires page alignment. These are in bytes. - const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), - allocator.BasePageBytes()); - const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), - allocator.BasePageBytes()); + const IndexRange cols_c(0, C.Cols()); + // `BindMemory` requires page alignment. These are in bytes. 
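
`BindB` and `BindC` both shrink the bound range inward to page boundaries because NUMA binding moves whole pages and the row padding is smaller than a page. A standalone sketch of that trimming (`bind_to_node` stands in for `Allocator::BindMemory`):

    #include <cstddef>
    #include <cstdint>

    // Shrink [begin, end) inward to page boundaries; skip the bind entirely if
    // nothing page-aligned remains (common for small test matrices).
    template <typename BindFn>  // e.g. bool(void*, size_t bytes, size_t node)
    bool BindPageAligned(uintptr_t begin, uintptr_t end, size_t page_bytes,
                         size_t node, const BindFn& bind_to_node) {
      begin = (begin + page_bytes - 1) / page_bytes * page_bytes;  // round up
      end = end / page_bytes * page_bytes;                         // round down
      if (begin >= end) return true;
      return bind_to_node(reinterpret_cast<void*>(begin), end - begin, node);
    }
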
+ const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(), + allocator.BasePageBytes()); + const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(), + allocator.BasePageBytes()); - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - for (size_t im = 0; im < C.Rows(); ++im) { - ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); - } + const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node(); + bool ok = true; + for (size_t im = 0; im < C.Rows(); ++im) { + ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node); } if (HWY_UNLIKELY(!ok)) { - HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(), - ranges_np.NumTasks()); + HWY_WARN("Failed to bind C (%zux%zu).", C.Rows(), C.Cols()); } } diff --git a/ops/matmul.h b/ops/matmul.h index 16cb51c..8c7d724 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -45,17 +45,18 @@ namespace gcpp { // This and `mr` are limited by the number of registers, which is generally // 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in // `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`. -constexpr size_t kNR = 4; +HWY_INLINE_VAR constexpr size_t kNR = 4; // Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because // we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile. // In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions // that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`, // or less on ISAs with fewer registers, or for the last few rows of A. -static constexpr size_t kMaxMR = 4; +HWY_INLINE_VAR constexpr size_t kMaxMR = 4; -IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, - size_t N, size_t sizeof_TC, size_t nr); +// Upper bound for per-worker B storage on the stack. Chosen such that one row +// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. +HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; struct MMOptions { uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`. 
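
The `kMaxKC` comment above checks out with a little arithmetic: 8 K BF16 elements are 16 KiB, so one KC-length row of A plus one of B exactly fills a 32 KiB L1 (the L1 size is taken from the comment, not queried at runtime):

    #include <cstddef>

    constexpr size_t kMaxKCSketch = 8 * 1024;  // value from above
    constexpr size_t kBF16Bytes = 2;
    constexpr size_t kAssumedL1Bytes = 32 * 1024;
    static_assert(2 * kMaxKCSketch * kBF16Bytes == kAssumedL1Bytes,
                  "one BF16 row of A plus one of B fills a 32 KiB L1");
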
@@ -66,12 +67,12 @@ struct MMOptions { struct MMParallelNone { template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, + size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster(); - func(range_np, worker); + func(range_n, worker); } template @@ -102,9 +103,8 @@ struct MMParallelNone { struct MMParallelWithinCluster { template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, size_t cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, + size_t inner_tasks, size_t cluster_idx, const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); const size_t pkg_idx = 0; @@ -112,7 +112,7 @@ struct MMParallelWithinCluster { const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); const IndexRangePartition worker_ranges = StaticPartition( - range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + range_n, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange(worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { func(worker_range, base + worker); @@ -156,17 +156,16 @@ struct MMParallelWithinCluster { }; struct MMParallelHierarchical { - // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is + // Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is // the granularity of per-cluster tasks. Calls `func(worker_range, worker)`. template - void ForNP(ThreadingContext& ctx, const IndexRange& range_np, - size_t nx_multiple, size_t inner_tasks, - HWY_MAYBE_UNUSED size_t caller_cluster_idx, - const Func& func) const { + void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple, + size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx, + const Func& func) const { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); HWY_DASSERT(caller_cluster_idx == 0); - // Single cluster: parallel-for over static partition of `range_np`. + // Single cluster: parallel-for over static partition of `range_n`. const size_t pkg_idx = 0; hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); @@ -174,7 +173,7 @@ struct MMParallelHierarchical { const size_t cluster_idx = 0; hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const IndexRangePartition worker_ranges = StaticPartition( - range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); + range_n, cluster.NumWorkers() * inner_tasks, n_multiple); return ParallelizeOneRange( worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { @@ -182,18 +181,18 @@ struct MMParallelHierarchical { }); } - // Assign each cluster a sub-range of `range_np` (typically hundreds). - const IndexRangePartition nx_ranges = - StaticPartition(range_np, num_clusters, nx_multiple); + // Assign each cluster a sub-range of `range_n` (typically hundreds). 
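
The hierarchical `ForN` is a two-level static partition: the N range is split across clusters, then each cluster's piece is split across its workers, always in multiples of `n_multiple`. A sequential stand-in with plain loops instead of `hwy::ThreadPool` (the simplified `StaticPartitionSketch` may round differently than the real `StaticPartition`):

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    using Range = std::pair<size_t, size_t>;  // [begin, end)

    // Split `r` into at most `num_tasks` pieces whose sizes are multiples of
    // `multiple` (except possibly the last piece).
    std::vector<Range> StaticPartitionSketch(Range r, size_t num_tasks,
                                             size_t multiple) {
      std::vector<Range> out;
      const size_t total = r.second - r.first;
      size_t per_task = (total + num_tasks - 1) / num_tasks;
      per_task = (per_task + multiple - 1) / multiple * multiple;
      for (size_t begin = r.first; begin < r.second; begin += per_task) {
        out.push_back({begin, std::min(begin + per_task, r.second)});
      }
      return out;
    }

    // Outer level: one piece per cluster. Inner level: per-worker pieces within
    // the cluster's piece. The real code runs these on nested thread pools and
    // offsets worker indices by cluster_idx * MaxWorkersPerCluster().
    template <typename Func>  // void(Range, size_t worker)
    void ForNSketch(size_t n, size_t num_clusters, size_t workers_per_cluster,
                    size_t multiple, const Func& func) {
      const auto cluster_ranges =
          StaticPartitionSketch({0, n}, num_clusters, multiple);
      for (size_t c = 0; c < cluster_ranges.size(); ++c) {
        const auto worker_ranges = StaticPartitionSketch(
            cluster_ranges[c], workers_per_cluster, multiple);
        for (size_t w = 0; w < worker_ranges.size(); ++w) {
          func(worker_ranges[w], c * workers_per_cluster + w);
        }
      }
    }
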
+ const IndexRangePartition n_ranges = + StaticPartition(range_n, num_clusters, n_multiple); ParallelizeOneRange( - nx_ranges, all_clusters, - [&](const IndexRange& nx_range, const size_t cluster_idx) { + n_ranges, all_clusters, + [&](const IndexRange& n_range, const size_t cluster_idx) { hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx); const size_t cluster_base = cluster_idx * ctx.pools.MaxWorkersPerCluster(); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( - nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); + n_range, cluster.NumWorkers() * inner_tasks, n_multiple); ParallelizeOneRange( worker_ranges, cluster, [&](const IndexRange& worker_range, size_t worker) { @@ -304,50 +303,29 @@ class StridedView { using StridedViewBF = StridedView; using StridedViewD = StridedView; -// Per-package storage for packed A. class MMStorage { public: // Compile-time bounds on matrix columns to enable pre-allocating storage // and reusing it across `MatMul` calls. static constexpr size_t kMaxK = 64 * 1024; - // Upper bound for per-worker B storage on the stack. Chosen such that one row - // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. - static constexpr size_t kMaxKC = 8 * 1024; - // Internally threaded; must not be called concurrently with the same - // `ThreadingContext` (used via `parallel`). - MMStorage(ThreadingContext& ctx) { - Allocator& allocator = ctx.allocator; - const size_t pkg_idx = 0; + MMStorage(const Allocator& allocator) + // 0.5 GiB. Must be padded, see `DoDecompressA`. + : A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator, + MatPadding::kOdd) {} - // 0.5 GiB per package. Must be padded, see `DoDecompressA`. - pkg_A_[pkg_idx].reset(new MatStorageT( - "pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd)); - - if (allocator.ShouldBind()) { - const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node(); - size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * - pkg_A_[pkg_idx]->ElementBytes(); - bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes()); - if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { - HWY_WARN("Failed to bind memory for package %zu", pkg_idx); - } - } - } - - // Returns per-package matrix view. Converting A=F32 to BF16 up-front is - // faster than on-the-fly when native BF16 is available: it only happens once, - // not per B tile row, and the cache footprint is smaller. - StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { + // Returns matrix view. Converting A=F32 to BF16 up-front is faster than + // on-the-fly when native BF16 is available: it only happens once, not per B + // tile row, and the cache footprint is smaller. + StridedViewBF A(const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); HWY_DASSERT(extents.cols <= kMaxK); - HWY_DASSERT(pkg_A_[pkg_idx] != nullptr); - return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), - extents.cols, pkg_A_[pkg_idx]->Stride()); + return StridedViewBF(const_cast(A_.Row(0)), extents.cols, + A_.Stride()); } private: - std::unique_ptr> pkg_A_[kMaxPackages]; + MatStorageT A_; }; //------------------------------------------------------------------------------ @@ -433,7 +411,7 @@ class MMConfig { MMConfig() = default; // for std::vector // `mr` is the number of A rows per call to `MMKernel::LoopKC`. // `MMOrder` is how to parallelize the outer loops. 
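
The "0.5 GiB" figure in the `MMStorage` comment above follows directly from the compile-time bounds, ignoring the small amount of row padding:

    #include <cstddef>

    constexpr size_t kMaxBatchSizeSketch = 4096;  // util/basics.h
    constexpr size_t kMaxKSketch = 64 * 1024;     // MMStorage::kMaxK
    constexpr size_t kBF16Bytes = 2;
    static_assert(kMaxBatchSizeSketch * kMaxKSketch * kBF16Bytes ==
                      (size_t{512} << 20),
                  "A_bf is 512 MiB = 0.5 GiB before padding");
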
- // `inner_tasks` chooses the within-cluster task granularity in `ForNP`. + // `inner_tasks` chooses the within-cluster task granularity in `ForN`. MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc, size_t kc_multiple, size_t nc_multiple, MMOrder order, int inner_tasks) @@ -470,8 +448,8 @@ class MMConfig { IndexRangePartition RangesOfKC(size_t K) const { return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_); } - IndexRangePartition RangesOfNC(IndexRange range_np) const { - return MaxSizePartition(range_np, nc_, nc_multiple_); + IndexRangePartition RangesOfNC(size_t N) const { + return MaxSizePartition(IndexRange(0, N), nc_, nc_multiple_); } MMOrder Order() const { return order_; } @@ -501,9 +479,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing std::vector MMCandidates(const Allocator& allocator, size_t M, size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, - const IndexRangePartition& ranges_np, - bool print_config); + size_t max_mr, size_t nr, bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the // main MatMul autotuner. @@ -609,7 +585,7 @@ class MMAutoTune { // `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range, // but choosing the same config for a larger M can result in multiple MC ranges. // Thus M less than this must have unique keys/configs. -static constexpr size_t kMaxTilesM = 8; +HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8; // Map of previously seen dimensions to index via linear search. class MMKeys { @@ -636,8 +612,8 @@ class MMKeys { return key; } - // We leave the search to callers so they can use dynamic-dispatched SIMD, - // which is not possible in this header. + // We leave the search to callers so they can use per-target SIMD, which is + // not possible in this header. hwy::Span Keys() const { return hwy::Span(keys_.get(), num_unique_); } @@ -674,26 +650,17 @@ class MMKeys { // Per-MatMul-shape state. struct MMPerKey { - MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N, - size_t sizeof_TC, size_t nr) - : ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) { - HWY_DASSERT(ranges_np.NumTasks() <= max_packages); - } - - // Only profile if enabled and the main autotuner finished (the par_a - // autotuner is per-package and we want to avoid synchronization). + // Only profile if enabled and the main autotuner finished. `autotune_par_a` + // might not be active if inputs are all BF16. bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); } - const IndexRangePartition ranges_np; MMAutoTune autotune; - MMAutoTune autotune_par_a[kMaxPackages]; + MMAutoTune autotune_par_a; }; // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive // `MatMulEnv`. struct MatMulEnv { - // Internally threaded; must not be called concurrently with the same - // `ThreadingContext`. explicit MatMulEnv(ThreadingContext& ctx); ThreadingContext& ctx; @@ -707,8 +674,13 @@ struct MatMulEnv { bool print_best = false; std::vector storage; - MMKeys keys[kMaxClusters]; - std::vector per_key[kMaxClusters]; + + struct PerCluster { + MMKeys keys; + std::vector per_key; + HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing + }; + std::vector per_cluster; // Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`. 
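
`MatMulEnv::PerCluster` pads each element with `HWY_ALIGNMENT` bytes so that clusters updating their own keys and autotune state never write to the same cache line. An equivalent standalone way to express the idea with `alignas` (64-byte lines assumed; the repo uses the explicit padding array instead):

    #include <cstddef>

    // One slot per cluster; the alignment forces sizeof to a multiple of 64,
    // so neighboring slots in a std::vector never share a cache line.
    struct alignas(64) PerClusterSketch {
      size_t num_keys = 0;      // stand-in for MMKeys
      void* per_key = nullptr;  // stand-in for std::vector<MMPerKey>
    };
    static_assert(sizeof(PerClusterSketch) % 64 == 0, "padded to full lines");
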
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV @@ -739,6 +711,7 @@ struct MMArgs { double scale; const float* HWY_RESTRICT add; + MMOptions options; size_t line_bytes; }; diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index 665e337..373f8aa 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -275,37 +275,28 @@ void TestTiny() { if (first_target == 0) first_target = HWY_TARGET; if (HWY_TARGET != first_target) return; - for (size_t max_packages : {1, 2}) { - ThreadingArgs threading_args; - threading_args.bind = Tristate::kTrue; - threading_args.max_packages = max_packages; - ThreadingContext ctx(threading_args); - MatMulEnv env(ctx); - NestedPools& pools = env.ctx.pools; + ThreadingArgs threading_args; + threading_args.bind = Tristate::kTrue; + ThreadingContext ctx(threading_args); + MatMulEnv env(ctx); + NestedPools& pools = env.ctx.pools; - if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) { - if (max_packages == 2) break; // we only have one package - } else { - // If less than the limit, we have already tested all num_packages. - if (env.ctx.topology.FullTopology().packages.size() < max_packages) break; - } - fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages, - env.ctx.topology.TopologyString(), pools.PinString()); + fprintf(stderr, "TestTiny: %s %s\n", env.ctx.topology.TopologyString(), + pools.PinString()); - pools.MaybeStartSpinning(threading_args.spin); + pools.MaybeStartSpinning(threading_args.spin); - for (size_t M = 1; M <= 12; ++M) { - for (size_t K = 1; K <= 64; K *= 2) { - for (size_t N = 4; N <= 64; N += max_packages * 4) { - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - TestMatMul(M, K, N, /*add=*/false, env, __LINE__); - } + for (size_t M = 1; M <= 12; ++M) { + for (size_t K = 1; K <= 64; K *= 2) { + for (size_t N = 4; N <= 64; N += 4) { + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); + TestMatMul(M, K, N, /*add=*/false, env, __LINE__); } } - pools.MaybeStopSpinning(threading_args.spin); } + pools.MaybeStopSpinning(threading_args.spin); } void TestAllMatMul() { diff --git a/util/allocator.cc b/util/allocator.cc index df2575e..f8bfdd5 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -160,11 +160,10 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) { // - supported by the OS (currently Linux only), // - the page size is known and 'reasonably small', preferably less than // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. - // - we successfully detected topology and there are multiple nodes; - // - there are multiple packages, because we shard by package_idx. + // - we successfully detected topology and there are multiple nodes. if constexpr (GEMMA_BIND) { if ((base_page_bytes_ != 0 && base_page_bytes_ <= 16 * 1024) && - topology.NumNodes() > 1 && topology.NumPackages() > 1) { + topology.NumNodes() > 1) { if (enable_bind) { // Ensure pages meet the alignment requirements of `AllocBytes`. HWY_ASSERT(base_page_bytes_ >= quantum_bytes_); diff --git a/util/allocator.h b/util/allocator.h index bf904c5..42e261c 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -149,7 +149,7 @@ class Allocator { } // Returns whether `BindMemory` can/should be called, i.e. we have page-level - // control over memory placement and multiple packages and NUMA nodes. 
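
The simplified constructor condition in `util/allocator.cc` now enables binding whenever pages are small enough and there is more than one NUMA node, regardless of package count. As a standalone predicate (`kGemmaBind` stands in for the `GEMMA_BIND` macro; the 16 KiB threshold is the one in the code):

    #include <cstddef>

    constexpr bool kGemmaBind = true;  // stand-in for GEMMA_BIND

    bool ShouldEnableBindSketch(size_t base_page_bytes, size_t num_nodes,
                                bool enable_bind) {
      if (!kGemmaBind) return false;
      const bool small_pages =
          base_page_bytes != 0 && base_page_bytes <= 16 * 1024;
      return small_pages && num_nodes > 1 && enable_bind;
    }
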
+ // control over memory placement and multiple NUMA nodes. bool ShouldBind() const { return should_bind_; } // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is diff --git a/util/basics.h b/util/basics.h index 7cdc17c..c8858e5 100644 --- a/util/basics.h +++ b/util/basics.h @@ -30,13 +30,6 @@ namespace gcpp { -// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the -// runtime `max_packages` does not exceed this. MatMul's outer per-package loop -// is disabled if this is 1. -HWY_INLINE_VAR constexpr size_t kMaxPackages = 1; - -HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink - // TODO: extend to 16k after updating non_eos. HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096; diff --git a/util/mat.h b/util/mat.h index b0de72d..9d838e2 100644 --- a/util/mat.h +++ b/util/mat.h @@ -455,6 +455,7 @@ class MatOwner { template class MatStorageT : public MatPtrT { public: + MatStorageT() = default; // for std::vector in Activations. MatStorageT(const char* name, Extents2D extents, const Allocator& allocator, MatPadding padding) : MatPtrT(name, extents) { diff --git a/util/threading_context.h b/util/threading_context.h index 847ce81..6bd6936 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -25,7 +25,7 @@ // IWYU pragma: begin_exports #include "util/allocator.h" #include "util/args.h" -#include "util/basics.h" // Tristate, kMaxPackages +#include "util/basics.h" // Tristate #include "util/threading.h" #include "util/topology.h" #include "hwy/profiler.h" @@ -41,7 +41,7 @@ class ThreadingArgs : public ArgsBase { // For BoundedTopology: size_t skip_packages; - size_t max_packages; + size_t max_packages = 1; size_t skip_clusters; size_t max_clusters; size_t skip_lps; @@ -58,13 +58,9 @@ class ThreadingArgs : public ArgsBase { void ForEach(const Visitor& visitor) { // These can be used to partition CPU packages/sockets and their // clusters/CCXs across several program instances. The default is to use - // all available resources on one package. Note that `kMaxPackages` is an - // upper bound on `max_packages`. + // all available resources on the first package. visitor(skip_packages, "skip_packages", size_t{0}, "Index of the first socket to use; default 0 = unlimited.", 2); - visitor(max_packages, "max_packages", size_t{1}, - "Max sockets to use; default = 1, 0 = unlimited.", 2); - HWY_ASSERT(max_packages <= kMaxPackages); visitor(skip_clusters, "skip_clusters", size_t{0}, "Index of the first CCX to use; default 0 = unlimited.", 2); visitor(max_clusters, "max_clusters", size_t{0}, @@ -105,7 +101,7 @@ struct ThreadingContext { hwy::Profiler& profiler; // Detects topology, subject to limits imposed by user-specified `args`. - // For example, if `args.max_packages` is 1, then `topology.NumPackages()` + // For example, if `args.max_clusters` is 1, then `topology.NumClusters()` // will be 1 regardless of the actual system topology. BoundedTopology topology;
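
With `max_packages` now fixed at its default of 1 and no longer exposed as a flag, callers bound parallelism via clusters and LPs instead, as the updated tests do. A minimal usage sketch mirroring the setup in `dot_test.cc` and `matmul_test.cc` (header paths as in this patch):

    #include "ops/matmul.h"
    #include "util/threading_context.h"

    int main() {
      gcpp::ThreadingArgs threading_args;
      threading_args.max_clusters = 1;  // clusters, not packages, are the knob
      gcpp::ThreadingContext ctx(threading_args);
      gcpp::MatMulEnv env(ctx);  // per-cluster MatMul state sized from ctx
      return 0;
    }
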