diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index a9685e2..152e6ce 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -599,7 +599,8 @@ class MMPerPackage { } else { // Always decompress. To reduce code size/compile time, we no longer // support a separate F32 kernel; most A are already BF16. - const StridedViewBF A_view = args_.env->storage.A(pkg_idx_, A.Extents()); + const StridedViewBF A_view = + args_.env->storage[cluster_idx_].A(pkg_idx_, A.Extents()); DecompressA(A, A_view); DispatchOrder(parallel_policy, A_view, B, C_rows); } diff --git a/ops/matmul.cc b/ops/matmul.cc index 812fe99..83fc036 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -402,7 +402,19 @@ IndexRangePartition MMRangesOfNP(ThreadingContext& ctx, size_t max_packages, NPMultiple(ctx.allocator, N, sizeof_TC, nr, num_packages)); } -MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx), storage(ctx) { +MatMulEnv::MatMulEnv(ThreadingContext& ctx) : ctx(ctx) { + // Create storage per cluster. This only applies to in-cluster parallelism. + // For nested and sequential parallelism, a single MMStorage is used. + size_t num_packages = ctx.topology.NumPackages(); + size_t num_clusters = 0; + for (size_t pkg_idx = 0; pkg_idx < num_packages; ++pkg_idx) { + num_clusters += ctx.topology.NumClusters(pkg_idx); + } + storage.reserve(num_clusters); + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + storage.push_back(MMStorage(ctx)); + } + char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); diff --git a/ops/matmul.h b/ops/matmul.h index 70c7d20..e76d37b 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -369,6 +369,7 @@ class MMStorage { StridedViewBF A(size_t pkg_idx, const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxBatchSize); HWY_DASSERT(extents.cols <= kMaxK); + HWY_DASSERT(pkg_A_[pkg_idx] != nullptr); return StridedViewBF(const_cast(pkg_A_[pkg_idx]->Row(0)), extents.cols, pkg_A_[pkg_idx]->Stride()); } @@ -733,7 +734,7 @@ struct MatMulEnv { // Whether to print the best config immediately after autotuning finished. bool print_best = false; - MMStorage storage; + std::vector storage; MMKeys keys; std::vector per_key;