Create separate MMStorage objects per cluster.

PiperOrigin-RevId: 802588625
This commit is contained in:
Marie White 2025-09-03 09:35:13 -07:00 committed by Copybara-Service
parent b7b3d353db
commit 74ffe079c4
3 changed files with 17 additions and 3 deletions

View File

@ -599,7 +599,8 @@ class MMPerPackage {
} else {
// Always decompress. To reduce code size/compile time, we no longer
// support a separate F32 kernel; most A are already BF16.
const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents());
const StridedViewBF A_view =
args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents());
DecompressA<MMParallelPolicyT>(A, A_view);
DispatchOrder(parallel_policy, A_view, B, C_rows);
}

View File

@ -402,7 +402,19 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages,
NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages));
}
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) {
MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) {
// Create storage per cluster. This only applies to in-cluster parallelism.
// For nested and sequential parallelism, a single MMStorage is used.
size_t num_packages = ctx.topology.NumPackages();
size_t num_clusters = 0;
for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) {
num_clusters += ctx.topology.NumClusters(pkg_idx);
}
storage.reserve(num_clusters);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
storage.push_back(MMStorage(ctx));
}
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

View File

@ -369,6 +369,7 @@ class MMStorage {
StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxBatchSize);
HWY_DASSERT(extents.cols <= kMaxK);
HWY_DASSERT(pkg_A_[pkg_idx] != nullptr);
return StridedViewBF(const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
extents.cols, pkg_A_[pkg_idx]->Stride());
}
@ -733,7 +734,7 @@ struct MatMulEnv {
// Whether to print the best config immediately after autotuning finished.
bool print_best = false;
MMStorage storage;
std::vector<MMStorage> storage;
MMKeys keys;
std::vector<MMPerKey> per_key;