Remove kMaxPackages and per-package-related code

matmul: remove kMaxClusters in favor of dynamic allocation
PiperOrigin-RevId: 802950348
Jan Wassenberg 2025-09-04 03:32:35 -07:00 committed by Copybara-Service
parent 7263ab8445
commit 4be4799727
11 changed files with 177 additions and 270 deletions

View File

@ -21,7 +21,6 @@
#include <stdint.h>
#include <atomic>
#include <memory>
#include <vector>
#include "gemma/configs.h" // ModelConfig
@ -62,11 +61,12 @@ struct AttentionActivations {
// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
const LayerConfig& layer_config = config.layer_configs[0];
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
return 1.0f /
sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}
AttentionActivations(
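
The change above only hoists `config.layer_configs[0]` into a local `layer_config`; both scale formulas are unchanged. A quick numeric check with illustrative dimensions (not taken from this commit):

  // model_dim = 3072, heads = 16, qkv_dim = 256:
  //   SqrtModelDimDivNumHeads: 1.0f / sqrtf(3072 / 16) = 1 / sqrtf(192) ~= 0.0722
  //   SqrtKeySize (fallthrough): 1.0f / sqrtf(256) = 0.0625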

View File

@ -1101,7 +1101,6 @@ void TestAllDot() {
// Limit workers because we only support `kMaxWorkers`.
ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext ctx(threading_args);

View File

@ -570,10 +570,29 @@ struct MMImpl {
// Returns existing entry for the given key or -1.
static HWY_INLINE intptr_t IndexOfKey(MMKeys::Key key, const MMKeys& keys) {
const hwy::Span<const uint64_t> all_keys = keys.Keys();
// TODO: SIMD scan
for (size_t i = 0; i < all_keys.size(); ++i) {
if (all_keys[i] == key) return static_cast<intptr_t>(i);
const hn::ScalableTag<uint64_t> d;
using V = hn::Vec<decltype(d)>;
const V broadcasted = Set(d, key);
const size_t N = hn::Lanes(d);
size_t i = 0;
if (all_keys.size() >= N) {
for (; i <= all_keys.size() - N; i += N) {
const intptr_t pos = hn::FindFirstTrue(
d, hn::Eq(broadcasted, hn::LoadU(d, &all_keys[i])));
if (pos >= 0) return static_cast<intptr_t>(i) + pos;
}
}
const size_t remaining = all_keys.size() - i;
if (HWY_LIKELY(remaining > 0)) {
HWY_DASSERT(remaining < N);
const V v = hn::LoadN(d, &all_keys[i], remaining);
const intptr_t pos = hn::FindFirstTrue(d, hn::Eq(broadcasted, v));
if (pos >= 0) return static_cast<intptr_t>(i) + pos;
}
return -1;
}
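
For readers unfamiliar with the Highway calls above, here is a minimal standalone sketch of the same pattern (static-target build; `FindKey` and its `std::vector` input are illustrative, not part of this commit). One caveat: `LoadN` zero-pads the lanes at or beyond `remaining`, so a key value of 0 could match padding; the sketch guards against that, while the code above presumably relies on keys encoding nonzero dimensions.

  #include <cstdint>
  #include <vector>

  #include "hwy/highway.h"

  namespace hn = hwy::HWY_NAMESPACE;

  // Returns the index of `key` in `keys`, or -1 if absent.
  intptr_t FindKey(uint64_t key, const std::vector<uint64_t>& keys) {
    const hn::ScalableTag<uint64_t> d;
    const auto broadcasted = hn::Set(d, key);
    const size_t N = hn::Lanes(d);
    size_t i = 0;
    // Whole vectors of `N` keys at a time.
    for (; i + N <= keys.size(); i += N) {
      const intptr_t pos =
          hn::FindFirstTrue(d, hn::Eq(broadcasted, hn::LoadU(d, &keys[i])));
      if (pos >= 0) return static_cast<intptr_t>(i) + pos;
    }
    // Tail of fewer than `N` keys; lanes >= remaining are zero.
    const size_t remaining = keys.size() - i;
    if (remaining != 0) {
      const auto v = hn::LoadN(d, &keys[i], remaining);
      const intptr_t pos = hn::FindFirstTrue(d, hn::Eq(broadcasted, v));
      if (pos >= 0 && static_cast<size_t>(pos) < remaining) {
        return static_cast<intptr_t>(i) + pos;
      }
    }
    return -1;
  }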
@ -582,6 +601,15 @@ struct MMImpl {
args.env->ctx.pools.MaxWorkersPerCluster();
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
template <class Func>
static void DispatchParallelism(ParallelismStrategy parallelism,
const Func& func) {
@ -651,7 +679,7 @@ struct MMImpl {
DispatchParallelism(
args.options.parallelism, [&](const auto& parallel) {
parallel.ForNP(args.env->ctx, all_K, multiple_K, inner_tasks,
parallel.ForN(args.env->ctx, all_K, multiple_K, inner_tasks,
args.options.cluster_idx,
[&](const IndexRange& range_K, size_t worker) {
do_range(all_M, range_K, worker);
@ -676,7 +704,7 @@ struct MMImpl {
static HWY_INLINE void DecompressA(const MatPtrT<float>& A,
const StridedViewBF A_view,
const MMArgs& args) {
MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a[/*pkg_idx=*/0];
MMAutoTune<MMParA>& autotune = args.per_key->autotune_par_a;
if (HWY_LIKELY(autotune.Best())) {
return DoDecompressA(A, A_view, *autotune.Best(), args);
@ -703,15 +731,6 @@ struct MMImpl {
}
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
template <typename T>
static StridedView<T> View(const MatPtrT<T>& AB, size_t r, size_t c,
size_t cols) {
HWY_DASSERT(c < AB.Cols());
HWY_DASSERT(cols <= AB.Cols() - c);
return StridedView<T>(const_cast<T*>(AB.Row(r)) + c, cols, AB.Stride());
}
template <typename TA>
static HWY_INLINE StridedViewBF MaybeDecompressA(const MatPtrT<TA>& A,
const MMArgs& args) {
@ -723,8 +742,7 @@ struct MMImpl {
// Always decompress. To reduce code size/compile time, we no longer
// support a separate F32 kernel; most A are already BF16.
const StridedViewBF A_view =
args.env->storage[args.options.cluster_idx].A(/*pkg_idx=*/0,
A.Extents());
args.env->storage[args.options.cluster_idx].A(A.Extents());
DecompressA(A, A_view, args);
return A_view;
}
@ -735,17 +753,16 @@ struct MMImpl {
// loops over the inner KC and MC. Member variables avoid long argument lists.
class MMState {
public:
MMState(const Extents2D A, const MMArgs& args, const MMConfig& config)
MMState(const Extents2D A, const size_t B_rows, const MMArgs& args,
const MMConfig& config)
: args_(args),
range_np_(args.per_key->ranges_np.Range(/*pkg_idx=*/0)),
range_n_(0, B_rows),
mr_(config.MR()),
ranges_mc_(config.RangesOfMC(A.rows)),
ranges_kc_(config.RangesOfKC(A.cols)),
ranges_nc_(config.RangesOfNC(range_np_)),
ranges_nc_(config.RangesOfNC(B_rows)),
order_(config.Order()),
inner_tasks_(config.InnerTasks()) {
HWY_DASSERT(args.per_key->ranges_np.NumTasks() == 1);
}
inner_tasks_(config.InnerTasks()) {}
// Called from `MatMul` from two places: either with the next autotune config,
// or with the best config.
@ -768,12 +785,12 @@ class MMState {
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
// allocation avoids passing a worker index.
static constexpr size_t B_stride_max_ =
MMStorage::kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16);
kMaxKC + 2 * Allocator::MaxLineBytes() / sizeof(BF16);
static constexpr size_t B_storage_max_ = kNR * B_stride_max_;
// Granularity of `ForNP`. B rows produce C columns, so we
// Granularity of `ForN`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing.
size_t MultipleNP(size_t sizeof_TC) const {
size_t MultipleN(size_t sizeof_TC) const {
return HWY_MAX(kNR, args_.line_bytes / sizeof_TC);
}
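
A concrete illustration of this granularity (a 64-byte cache line is assumed here; `args_.line_bytes` is whatever the allocator reports at runtime):

  // line_bytes = 64, TC = float (4 B): HWY_MAX(kNR = 4, 64 / 4) = 16 columns,
  //   so each worker's slice of C starts on a fresh cache line.
  // line_bytes = 64, TC = BF16 (2 B):  HWY_MAX(4, 32) = 32 columns.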
@ -812,8 +829,8 @@ class MMState {
Stride(MatPadding::kOdd, K, sizeof(BF16), args_.line_bytes);
// Similar to `loop_nc` below, but here we hoisted `A_view`.
parallel.ForNP(
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
parallel.ForN(
args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
args_.options.cluster_idx,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone;
@ -861,8 +878,8 @@ class MMState {
}
};
parallel.ForNP(
args_.env->ctx, range_np_, MultipleNP(sizeof(TC)), inner_tasks_,
parallel.ForN(
args_.env->ctx, range_n_, MultipleN(sizeof(TC)), inner_tasks_,
args_.options.cluster_idx,
[&](const IndexRange& range_nc, size_t worker) HWY_ATTR {
MMZone mm_zone;
@ -881,7 +898,7 @@ class MMState {
});
}
// Parallel loops over mc/nc blocks of M/range_np, single K.
// Parallel loops over mc/nc blocks of M/range_n, single K.
// Fills `mc x nc` sections of C directly, in parallel.
template <typename TB, typename TC, class ParallelT>
HWY_INLINE void DoNT_MT(ParallelT parallel, const StridedViewBF A,
@ -923,7 +940,7 @@ class MMState {
const MatPtrT<TB>& B, RowPtrs<TC> C_rows) const {
static const auto zone = args_.env->ctx.profiler.AddZone("MM.NT_MT_K");
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
HWY_DASSERT(kc_max <= kMaxKC);
const size_t B_stride =
Stride(MatPadding::kOdd, kc_max, sizeof(BF16), args_.line_bytes);
// Sequential loop over NC/MC/KC, for when the M/N loops are
@ -1002,7 +1019,7 @@ class MMState {
const MMArgs args_; // copy for locality
const IndexRange range_np_;
const IndexRange range_n_;
// From MMConfig:
const size_t mr_;
const IndexRangePartition ranges_mc_;
@ -1036,38 +1053,33 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C, MMOptions options = MMOptions()) {
const Allocator& allocator = env.ctx.allocator;
HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx];
RowPtrs<TC> C_rows =
GetOrSetTempRowPtrs(C, env.row_ptrs[options.cluster_idx]);
const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Rows();
const size_t K = A.Cols();
const size_t N = B.Rows();
const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
intptr_t index = MMImpl::IndexOfKey(key, env.keys[options.cluster_idx]);
intptr_t index = MMImpl::IndexOfKey(key, per_cluster.keys);
// First time we see this shape/key.
if (HWY_UNLIKELY(index < 0)) {
env.keys[options.cluster_idx].Append(key, allocator);
size_t max_packages = kMaxPackages;
// For low-batch, multiple sockets only help if binding is enabled.
if (!allocator.ShouldBind() && M <= 4) {
max_packages = 1;
}
per_cluster.keys.Append(key, allocator);
// invalidates `MMAutoTune::Best()`
std::vector<MMPerKey>& stored_keys = env.per_key[options.cluster_idx];
index = stored_keys.size();
stored_keys.push_back(MMPerKey(env.ctx, max_packages, N, sizeof(TC), kNR));
std::vector<MMPerKey>& per_keys = per_cluster.per_key;
index = per_keys.size();
per_keys.push_back(MMPerKey());
}
MMPerKey& per_key = env.per_key[options.cluster_idx][index];
MMPerKey& per_key = per_cluster.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.Scale(),
add, options);
if (HWY_LIKELY(tuner.Best())) {
const MMState state(A.Extents(), args, *tuner.Best());
const MMState state(A.Extents(), B.Rows(), args, *tuner.Best());
const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
state.DispatchParallelism(A_view, B, C_rows);
return &per_key;
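
For orientation, a hypothetical call site for this signature; the setup names and sizes below are assumptions, while the constructors and the `MatMul` call itself match this commit:

  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  MatMulEnv env(ctx);
  const size_t M = 4, K = 2048, N = 256;  // N must be a multiple of kNR
  MatStorageT<BF16>  A("A", Extents2D(M, K), ctx.allocator, MatPadding::kOdd);
  MatStorageT<BF16>  B("B", Extents2D(N, K), ctx.allocator, MatPadding::kOdd);
  MatStorageT<float> C("C", Extents2D(M, N), ctx.allocator, MatPadding::kOdd);
  // First call with this (M, K, N) appends a key and starts autotuning;
  // later calls reuse per_key->autotune.Best().
  MMPerKey* per_key = MatMul(A, B, /*add=*/nullptr, env, C);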
@ -1092,12 +1104,12 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
}
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), kMaxMR,
kNR, per_key.ranges_np, env.print_config));
kNR, env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();
const uint64_t t0 = hwy::timer::Start();
MMState state(A.Extents(), args, cfg);
MMState state(A.Extents(), B.Rows(), args, cfg);
const StridedViewBF A_view = MMImpl::MaybeDecompressA(A, args);
state.DispatchParallelism(A_view, B, C_rows);
const uint64_t t1 =

View File

@ -21,7 +21,6 @@
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <vector>
#include "util/allocator.h"
@ -65,7 +64,7 @@ class GenerateCandidates {
public:
GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
bool print_config)
: allocator_(allocator),
M_(M),
K_(K),
@ -79,7 +78,6 @@ class GenerateCandidates {
// up to the line size. Both A and B are BF16.
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
nc_multiple_(allocator.StepBytes() / sizeof_TC),
ranges_np_(ranges_np),
print_config_(print_config) {}
std::vector<MMConfig> operator()() const {
@ -177,8 +175,7 @@ class GenerateCandidates {
allocator_.L1Bytes() * (sizeof(BF16) + sizeof(SfpStream));
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
kc_max =
RoundDownWithFloor(HWY_MIN(kc_max, MMStorage::kMaxKC), kc_multiple_);
kc_max = RoundDownWithFloor(HWY_MIN(kc_max, kMaxKC), kc_multiple_);
kc_max = HWY_MIN(kc_max, K_);
SizeVec all_kc(1, kc_max);
@ -258,32 +255,30 @@ class GenerateCandidates {
// The number of (possibly L3 resident) B rows per `NT_MT` task.
SizeVec NC(size_t mr, size_t mc, size_t kc, MMOrder order) const {
const size_t np_max = ranges_np_.TaskSize();
size_t nc_max = np_max;
size_t nc_max = N_;
// Only if there will be reuse of B: choose the largest `nc_max` (C cols)
// such that `nc x kc` of B and `mc x nc` of `partial` or `C` fit in L3.
// Otherwise, leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * sizeof_TC_);
nc_max =
HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), np_max);
nc_max = HWY_MIN(hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc), N_);
}
HWY_DASSERT(nc_max != 0);
nc_max = RoundDownWithFloor(nc_max, nc_multiple_);
// If there are going to be multiple ranges, anything more than half would
// be imbalanced and suboptimal.
if (nc_max < np_max && nc_max >= np_max / 2) {
nc_max = RoundDownWithFloor(np_max / 2, nc_multiple_);
if (nc_max < N_ && nc_max >= N_ / 2) {
nc_max = RoundDownWithFloor(N_ / 2, nc_multiple_);
}
// Non-block calls ForNP, which ignores `range_nc` and uses `range_np`.
if (!IsBlock(order)) return SizeVec(1, np_max);
if (!IsBlock(order)) return SizeVec(1, N_);
SizeVec all_nc(1, nc_max);
// Avoid proposing nc > N.
if (np_max > nc_multiple_) {
if (N_ > nc_multiple_) {
// Large L3, but its behavior and characteristics vary across platforms,
// hence autotune a wider range of nc than the other dimensions.
size_t reps = 10;
@ -292,8 +287,7 @@ class GenerateCandidates {
size_t prev = nc_max;
for (size_t rep = 0; rep < reps; ++rep) {
const size_t div =
PrevDivisor(nc_multiple_, prev, np_max, nc_multiple_);
const size_t div = PrevDivisor(nc_multiple_, prev, N_, nc_multiple_);
prev = div ? div : RoundDownWithFloor(prev / 2, nc_multiple_);
all_nc.push_back(prev);
if (prev == nc_multiple_) break;
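
A rough worked example of the L3 bound above (cache and tile sizes are illustrative, not from this commit):

  // L3Bytes = 32 MiB, kc = 2048, mc = 64, TC = float:
  //   bytes_per_nc = 2048 * sizeof(BF16) + 64 * sizeof(float) = 4096 + 256 = 4352
  //   nc_max = DivCeil(32 MiB, 4352) = 7711, then clamped to N_, rounded down to
  //   nc_multiple_, and repeatedly reduced via PrevDivisor to form the candidate
  //   list all_nc.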
@ -346,8 +340,6 @@ class GenerateCandidates {
const size_t kc_multiple_;
const size_t nc_multiple_;
IndexRangePartition ranges_np_;
const bool print_config_;
};
@ -357,58 +349,19 @@ class GenerateCandidates {
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
ranges_np, print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(const Allocator& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.BasePageBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = allocator.LineBytes() / sizeof_TC;
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
np_multiple = min_multiple;
}
// This happens in tests with small N, hence do not assert.
if (N % (np_multiple * num_packages) && N >= 128) {
static std::atomic_flag warned = ATOMIC_FLAG_INIT;
if (!warned.test_and_set()) {
HWY_WARN(
"NPMultiple: N=%zu still not divisible by np_multiple=%zu * "
"num_packages=%zu\n",
N, np_multiple, num_packages);
}
np_multiple = nr;
}
}
return np_multiple;
}
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
size_t N, size_t sizeof_TC, size_t nr) {
const size_t num_packages = HWY_MIN(max_packages, ctx.pools.NumPackages());
return StaticPartition(
IndexRange(0, N), num_packages,
NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
print_config)();
}
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
// Create storage per cluster. This only applies to in-cluster parallelism.
// For nested and sequential parallelism, a single MMStorage is used.
const size_t num_clusters = ctx.pools.AllClusters(/*pkg_idx=*/0).NumWorkers();
per_cluster.resize(num_clusters);
storage.reserve(num_clusters);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
storage.push_back(MMStorage(ctx));
storage.push_back(MMStorage(ctx.allocator));
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(kMaxBatchSize)); // C
}
@ -423,13 +376,9 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
PROFILER_ZONE("Startup.BindB");
const IndexRangePartition ranges_np =
MMRangesOfNP(ctx, kMaxPackages, B.Rows(), sizeof_TC, kNR);
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(rows_b.begin()));
uintptr_t end = begin + rows_b.Num() * B.Stride() * B.ElementBytes();
const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node();
uintptr_t begin = reinterpret_cast<uintptr_t>(B.RowBytes(0));
uintptr_t end = begin + B.Rows() * B.Stride() * B.ElementBytes();
// B row padding is less than the page size, so only bind the subset that
// is page-aligned.
begin = hwy::RoundUpTo(begin, allocator.BasePageBytes());
@ -437,7 +386,6 @@ void BindB(ThreadingContext& ctx, MatPtr& B, size_t sizeof_TC) {
if (HWY_LIKELY(begin != end)) {
allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
}
}
}
// C is BF16/float
@ -447,25 +395,20 @@ void BindC(ThreadingContext& ctx, MatPtr& C) {
PROFILER_ZONE("Startup.BindC");
const IndexRangePartition ranges_np =
MMRangesOfNP(ctx, kMaxPackages, C.Cols(), C.ElementBytes(), kNR);
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
const IndexRange cols_c(0, C.Cols());
// `BindMemory` requires page alignment. These are in bytes.
const size_t begin = hwy::RoundUpTo(cols_c.begin() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t end = hwy::RoundDownTo(cols_c.end() * C.ElementBytes(),
allocator.BasePageBytes());
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
const size_t node = ctx.topology.GetCluster(/*pkg_idx=*/0, 0).Node();
bool ok = true;
for (size_t im = 0; im < C.Rows(); ++im) {
ok &= allocator.BindMemory(C.RowBytes(im) + begin, end - begin, node);
}
}
if (HWY_UNLIKELY(!ok)) {
HWY_WARN("Failed to bind C (%zux%zu), %zu packages.", C.Rows(), C.Cols(),
ranges_np.NumTasks());
HWY_WARN("Failed to bind C (%zux%zu).", C.Rows(), C.Cols());
}
}
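
A small numeric illustration of the page-alignment step shared by BindB and BindC (4 KiB base pages assumed; addresses are hypothetical):

  // begin = 0x7f3000001200 -> RoundUpTo(begin, 4096)  = 0x7f3000002000
  // end   = 0x7f3000123f00 -> RoundDownTo(end, 4096)  = 0x7f3000123000
  // Only the aligned interior [begin, end) is handed to BindMemory; the
  // unaligned fringes stay on whichever node first touched them.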

View File

@ -45,17 +45,18 @@ namespace gcpp {
// This and `mr` are limited by the number of registers, which is generally
// 32 but 16 for AVX2. `kNR` == 4 enables the `StoreInterleaved4` transpose in
// `MMStoreHorizontalSumsIntoC`. We ensure `C.Cols() % kNR == 0`.
constexpr size_t kNR = 4;
HWY_INLINE_VAR constexpr size_t kNR = 4;
// Choosing `kMaxMR == kNR` minimizes the ratio of loads to FMA, because
// we load `kNR + kMaxMR` vectors per `kMaxMR * kNR` element tile.
// In general, `M` (batch size) is not a multiple of `kMaxMR`. Thus functions
// that load or store a tile are parameterized on `kRowsAC`: usually `kMaxMR`,
// or less on ISAs with fewer registers, or for the last few rows of A.
static constexpr size_t kMaxMR = 4;
HWY_INLINE_VAR constexpr size_t kMaxMR = 4;
IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
size_t N, size_t sizeof_TC, size_t nr);
// Upper bound for per-worker B storage on the stack. Chosen such that one BF16
// row of A plus one of B fit in 32 KiB L1, though up to `kMaxMR` rows of A and
// `kNR` rows of B may be in flight.
HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024;
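
The 8 K bound follows from the L1 budget described in the comment:

  // kMaxKC = 8192 BF16 values: one row of A (8192 * 2 B = 16 KiB) plus one row
  // of B (16 KiB) fills a typical 32 KiB L1. B_stride_max_ in matmul-inl.h adds
  // roughly two maximum-size cache lines of padding per row on top of this.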
struct MMOptions {
uint32_t cluster_idx = 0; // for `parallelism == kWithinCluster`.
@ -66,12 +67,12 @@ struct MMOptions {
struct MMParallelNone {
template <class Func>
void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
void ForN(ThreadingContext& ctx, const IndexRange& range_n,
size_t /*n_multiple*/, size_t inner_tasks, size_t cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t worker = cluster_idx * ctx.pools.MaxWorkersPerCluster();
func(range_np, worker);
func(range_n, worker);
}
template <class Func>
@ -102,9 +103,8 @@ struct MMParallelNone {
struct MMParallelWithinCluster {
template <class Func>
void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
size_t nx_multiple, size_t inner_tasks, size_t cluster_idx,
const Func& func) const {
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, size_t cluster_idx, const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
const size_t pkg_idx = 0;
@ -112,7 +112,7 @@ struct MMParallelWithinCluster {
const size_t base = cluster_idx * ctx.pools.MaxWorkersPerCluster();
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t worker) {
func(worker_range, base + worker);
@ -156,17 +156,16 @@ struct MMParallelWithinCluster {
};
struct MMParallelHierarchical {
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
// Cluster/CCX-aware parallel-for over B rows in `range_n`. `n_multiple` is
// the granularity of per-cluster tasks. Calls `func(worker_range, worker)`.
template <class Func>
void ForNP(ThreadingContext& ctx, const IndexRange& range_np,
size_t nx_multiple, size_t inner_tasks,
HWY_MAYBE_UNUSED size_t caller_cluster_idx,
void ForN(ThreadingContext& ctx, const IndexRange& range_n, size_t n_multiple,
size_t inner_tasks, HWY_MAYBE_UNUSED size_t caller_cluster_idx,
const Func& func) const {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
HWY_DASSERT(caller_cluster_idx == 0);
// Single cluster: parallel-for over static partition of `range_np`.
// Single cluster: parallel-for over static partition of `range_n`.
const size_t pkg_idx = 0;
hwy::ThreadPool& all_clusters = ctx.pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
@ -174,7 +173,7 @@ struct MMParallelHierarchical {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
range_n, cluster.NumWorkers() * inner_tasks, n_multiple);
return ParallelizeOneRange(
worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t worker) {
@ -182,18 +181,18 @@ struct MMParallelHierarchical {
});
}
// Assign each cluster a sub-range of `range_np` (typically hundreds).
const IndexRangePartition nx_ranges =
StaticPartition(range_np, num_clusters, nx_multiple);
// Assign each cluster a sub-range of `range_n` (typically hundreds).
const IndexRangePartition n_ranges =
StaticPartition(range_n, num_clusters, n_multiple);
ParallelizeOneRange(
nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) {
n_ranges, all_clusters,
[&](const IndexRange& n_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = ctx.pools.Cluster(pkg_idx, cluster_idx);
const size_t cluster_base =
cluster_idx * ctx.pools.MaxWorkersPerCluster();
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
n_range, cluster.NumWorkers() * inner_tasks, n_multiple);
ParallelizeOneRange(
worker_ranges, cluster,
[&](const IndexRange& worker_range, size_t worker) {
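
A worked example of the two-level split (cluster and worker counts are hypothetical):

  // range_n = [0, 4096), 4 clusters, n_multiple = 16, inner_tasks = 2,
  // 8 workers per cluster:
  //   per-cluster ranges: 4 x 1024 rows (StaticPartition over all_clusters)
  //   per-worker ranges:  1024 / (8 * 2) = 64 rows each, a multiple of 16,
  //   passed to func(worker_range, cluster_base + worker).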
@ -304,50 +303,29 @@ class StridedView {
using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;
// Per-package storage for packed A.
class MMStorage {
public:
// Compile-time bounds on matrix columns to enable pre-allocating storage
// and reusing it across `MatMul` calls.
static constexpr size_t kMaxK = 64 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024;
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(ThreadingContext& ctx) {
Allocator& allocator = ctx.allocator;
const size_t pkg_idx = 0;
MMStorage(const Allocator& allocator)
// 0.5 GiB. Must be padded, see `DoDecompressA`.
: A_("A_bf", Extents2D(kMaxBatchSize, kMaxK), allocator,
MatPadding::kOdd) {}
// 0.5 GiB per package. Must be padded, see `DoDecompressA`.
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxBatchSize, kMaxK), allocator, MatPadding::kOdd));
if (allocator.ShouldBind()) {
const size_t node = ctx.topology.GetCluster(pkg_idx, 0).Node();
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.BasePageBytes());
if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
}
}
}
// Returns per-package matrix view. Converting A=F32 to BF16 up-front is
// faster than on-the-fly when native BF16 is available: it only happens once,
// not per B tile row, and the cache footprint is smaller.
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
// Returns matrix view. Converting A=F32 to BF16 up-front is faster than
// on-the-fly when native BF16 is available: it only happens once, not per B
// tile row, and the cache footprint is smaller.
StridedViewBF A(const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
HWY_DASSERT(extents.cols <= kMaxK);
HWY_DASSERT(pkg_A_[pkg_idx] != nullptr);
return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
extents.cols, pkg_A_[pkg_idx]->Stride());
return StridedViewBF(const_cast<BF16*>(A_.Row(0)), extents.cols,
A_.Stride());
}
private:
std::unique_ptr<MatStorageT<BF16>> pkg_A_[kMaxPackages];
MatStorageT<BF16> A_;
};
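
The "0.5 GiB" in the comment follows from the compile-time bounds (`kMaxBatchSize` is defined in util/basics.h, also touched by this commit):

  // kMaxBatchSize (4096) rows * kMaxK (64 * 1024) cols * sizeof(BF16) (2 B)
  //   = 512 MiB, allocated once per entry of MatMulEnv::storage (one per
  //   cluster) and reused across MatMul calls instead of per-package copies.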
//------------------------------------------------------------------------------
@ -433,7 +411,7 @@ class MMConfig {
MMConfig() = default; // for std::vector
// `mr` is the number of A rows per call to `MMKernel::LoopKC`.
// `MMOrder` is how to parallelize the outer loops.
// `inner_tasks` chooses the within-cluster task granularity in `ForNP`.
// `inner_tasks` chooses the within-cluster task granularity in `ForN`.
MMConfig(size_t K, size_t N, size_t mr, size_t mc, size_t kc, size_t nc,
size_t kc_multiple, size_t nc_multiple, MMOrder order,
int inner_tasks)
@ -470,8 +448,8 @@ class MMConfig {
IndexRangePartition RangesOfKC(size_t K) const {
return MaxSizePartition(IndexRange(0, K), kc_, kc_multiple_);
}
IndexRangePartition RangesOfNC(IndexRange range_np) const {
return MaxSizePartition(range_np, nc_, nc_multiple_);
IndexRangePartition RangesOfNC(size_t N) const {
return MaxSizePartition(IndexRange(0, N), nc_, nc_multiple_);
}
MMOrder Order() const { return order_; }
@ -501,9 +479,7 @@ static_assert(sizeof(MMConfig) == 32); // for faster indexing
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config);
size_t max_mr, size_t nr, bool print_config);
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
// main MatMul autotuner.
@ -609,7 +585,7 @@ class MMAutoTune {
// `MMOrder::kNT[_K]` are no longer allowed. They require a single MC range,
// but choosing the same config for a larger M can result in multiple MC ranges.
// Thus M less than this must have unique keys/configs.
static constexpr size_t kMaxTilesM = 8;
HWY_INLINE_VAR constexpr size_t kMaxTilesM = 8;
// Map of previously seen dimensions to index via linear search.
class MMKeys {
@ -636,8 +612,8 @@ class MMKeys {
return key;
}
// We leave the search to callers so they can use dynamic-dispatched SIMD,
// which is not possible in this header.
// We leave the search to callers so they can use per-target SIMD, which is
// not possible in this header.
hwy::Span<const Key> Keys() const {
return hwy::Span<const Key>(keys_.get(), num_unique_);
}
@ -674,26 +650,17 @@ class MMKeys {
// Per-MatMul-shape state.
struct MMPerKey {
MMPerKey(ThreadingContext& ctx, size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr)
: ranges_np(MMRangesOfNP(ctx, max_packages, N, sizeof_TC, nr)) {
HWY_DASSERT(ranges_np.NumTasks() <= max_packages);
}
// Only profile if enabled and the main autotuner finished (the par_a
// autotuner is per-package and we want to avoid synchronization).
// Only profile if enabled and the main autotuner finished. `autotune_par_a`
// might not be active if inputs are all BF16.
bool WantProfile() const { return PROFILER_ENABLED != 0 && autotune.Best(); }
const IndexRangePartition ranges_np;
MMAutoTune<MMConfig> autotune;
MMAutoTune<MMParA> autotune_par_a[kMaxPackages];
MMAutoTune<MMParA> autotune_par_a;
};
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext`.
explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext& ctx;
@ -707,8 +674,13 @@ struct MatMulEnv {
bool print_best = false;
std::vector<MMStorage> storage;
MMKeys keys[kMaxClusters];
std::vector<MMPerKey> per_key[kMaxClusters];
struct PerCluster {
MMKeys keys;
std::vector<MMPerKey> per_key;
HWY_MAYBE_UNUSED uint8_t padding[HWY_ALIGNMENT]; // prevent false sharing
};
std::vector<PerCluster> per_cluster;
// Storage for arbitrary output rows, see `MatPtr::AllocateAndAttachRowPtrs`.
// Most MatMul callers use strided MatPtr, but GemmaAttention::ComputeQKV
@ -739,6 +711,7 @@ struct MMArgs {
double scale;
const float* HWY_RESTRICT add;
MMOptions options;
size_t line_bytes;
};
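
Putting the new pieces together, a sketch of how a call reaches its per-cluster, per-shape state (member names are from this commit; the flow mirrors `MatMul` in matmul-inl.h above):

  MatMulEnv::PerCluster& pc = env.per_cluster[options.cluster_idx];
  const MMKeys::Key key = MMKeys::KeyFromDims(M, K, N);
  intptr_t idx = MMImpl::IndexOfKey(key, pc.keys);
  if (idx < 0) {  // first call with this shape
    pc.keys.Append(key, env.ctx.allocator);
    idx = static_cast<intptr_t>(pc.per_key.size());
    pc.per_key.push_back(MMPerKey());
  }
  MMPerKey& per_key = pc.per_key[idx];

The HWY_ALIGNMENT padding keeps adjacent PerCluster entries on separate cache lines, so concurrent MatMul calls on different clusters do not false-share while appending keys or updating autotune state.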

View File

@ -275,28 +275,20 @@ void TestTiny() {
if (first_target == 0) first_target = HWY_TARGET;
if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) {
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages;
ThreadingContext ctx(threading_args);
MatMulEnv env(ctx);
NestedPools& pools = env.ctx.pools;
if constexpr (GEMMA_DISABLE_TOPOLOGY || kMaxPackages == 1) {
if (max_packages == 2) break; // we only have one package
} else {
// If less than the limit, we have already tested all num_packages.
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
}
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
env.ctx.topology.TopologyString(), pools.PinString());
fprintf(stderr, "TestTiny: %s %s\n", env.ctx.topology.TopologyString(),
pools.PinString());
pools.MaybeStartSpinning(threading_args.spin);
for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) {
for (size_t N = 4; N <= 64; N += max_packages * 4) {
for (size_t N = 4; N <= 64; N += 4) {
TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<BF16, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
TestMatMul<F32, BF16, F32>(M, K, N, /*add=*/false, env, __LINE__);
@ -305,7 +297,6 @@ void TestTiny() {
}
}
pools.MaybeStopSpinning(threading_args.spin);
}
}
void TestAllMatMul() {

View File

@ -160,11 +160,10 @@ Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
// - supported by the OS (currently Linux only),
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
// - we successfully detected topology and there are multiple nodes.
if constexpr (GEMMA_BIND) {
if ((base_page_bytes_ != 0 && base_page_bytes_ <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
topology.NumNodes() > 1) {
if (enable_bind) {
// Ensure pages meet the alignment requirements of `AllocBytes`.
HWY_ASSERT(base_page_bytes_ >= quantum_bytes_);

View File

@ -149,7 +149,7 @@ class Allocator {
}
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
// control over memory placement and multiple packages and NUMA nodes.
// control over memory placement and multiple NUMA nodes.
bool ShouldBind() const { return should_bind_; }
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is

View File

@ -30,13 +30,6 @@
namespace gcpp {
// Maximum number of packages (CPU sockets) to use. `ThreadingArgs` verifies the
// runtime `max_packages` does not exceed this. MatMul's outer per-package loop
// is disabled if this is 1.
HWY_INLINE_VAR constexpr size_t kMaxPackages = 1;
HWY_INLINE_VAR constexpr size_t kMaxClusters = 128; // TODO: shrink
// TODO: extend to 16k after updating non_eos.
HWY_INLINE_VAR constexpr size_t kMaxBatchSize = 4096;

View File

@ -455,6 +455,7 @@ class MatOwner {
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
MatStorageT() = default; // for std::vector in Activations.
MatStorageT(const char* name, Extents2D extents, const Allocator& allocator,
MatPadding padding)
: MatPtrT<MatT>(name, extents) {

View File

@ -25,7 +25,7 @@
// IWYU pragma: begin_exports
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h" // Tristate, kMaxPackages
#include "util/basics.h" // Tristate
#include "util/threading.h"
#include "util/topology.h"
#include "hwy/profiler.h"
@ -41,7 +41,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
// For BoundedTopology:
size_t skip_packages;
size_t max_packages;
size_t max_packages = 1;
size_t skip_clusters;
size_t max_clusters;
size_t skip_lps;
@ -58,13 +58,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
void ForEach(const Visitor& visitor) {
// These can be used to partition CPU packages/sockets and their
// clusters/CCXs across several program instances. The default is to use
// all available resources on one package. Note that `kMaxPackages` is an
// upper bound on `max_packages`.
// all available resources on the first package.
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{1},
"Max sockets to use; default = 1, 0 = unlimited.", 2);
HWY_ASSERT(max_packages <= kMaxPackages);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
@ -105,7 +101,7 @@ struct ThreadingContext {
hwy::Profiler& profiler;
// Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_packages` is 1, then `topology.NumPackages()`
// For example, if `args.max_clusters` is 1, then `topology.NumClusters()`
// will be 1 regardless of the actual system topology.
BoundedTopology topology;