From bc0c0bac8b62f858fe9c28b4e6f3e39e4a07c5c2 Mon Sep 17 00:00:00 2001 From: Marie White Date: Fri, 29 Aug 2025 08:38:19 -0700 Subject: [PATCH] Add non-threading parallel policy. PiperOrigin-RevId: 800913294 --- ops/matmul.h | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/ops/matmul.h b/ops/matmul.h index 752bad1..dda9673 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -71,6 +71,48 @@ struct MMOptions { uint8_t cluster_idx_ = 0; }; +struct MMSequentialPolicy { + template + static void ForPkg(ThreadingContext& ctx, const size_t max_packages, + const Func& func) { + func(/*pkg_idx=*/0); + } + + template + static void ForNP(ThreadingContext& ctx, const IndexRange& range_np, + size_t nx_multiple, size_t inner_tasks, size_t pkg_idx, + const Func& func) { + HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + func(range_np, base_idx); + } + + template + static void ForRangesMC_NC(ThreadingContext& ctx, + const IndexRangePartition& ranges_mc, + const IndexRangePartition& ranges_nc, + size_t pkg_idx, const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + + for (size_t i = 0; i < ranges_mc.NumTasks(); ++i) { + const IndexRange range_mc = ranges_mc.Range(i); + for (size_t j = 0; j < ranges_nc.NumTasks(); ++j) { + const IndexRange range_nc = ranges_nc.Range(j); + func(range_mc, range_nc, base_idx); + } + } + } + + template + static void ForRangeMC(ThreadingContext& ctx, const IndexRange& range_mc, + size_t pkg_idx, const Func& func) { + const size_t base_idx = pkg_idx * ctx.pools.MaxWorkersPerPackage(); + for (uint64_t row_a = range_mc.begin(); row_a < range_mc.end(); ++row_a) { + func(row_a, base_idx); + } + } +}; + struct MMNestedParallelPolicy { template static void ForPkg(ThreadingContext& ctx, const size_t max_packages,