diff --git a/util/threading.h b/util/threading.h index 1b3ad41..14676b4 100644 --- a/util/threading.h +++ b/util/threading.h @@ -226,6 +226,8 @@ class NestedPools { BoundedSlice cluster_slice = BoundedSlice(), BoundedSlice lp_slice = BoundedSlice()); + bool AllPinned() const { return all_pinned_; } + // Subject to `use_spinning`, enables spin waits with the goal of reducing the // latency of barrier synchronization. We only spin during Generate to avoid // wasting energy during long waits. If `use_spinning` is kDefault, we first diff --git a/util/threading_test.cc b/util/threading_test.cc index 8e39616..4ca69ff 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -276,16 +276,27 @@ TEST(ThreadingTest, BenchJoin) { constexpr size_t kInputs = 1; static hwy::FuncInput inputs[kInputs]; - const auto measure = [&](hwy::ThreadPool& pool, const char* caption) { + const auto measure = [&](hwy::ThreadPool& pool, bool spin, + const char* caption) { inputs[0] = static_cast(hwy::Unpredictable1() * pool.NumWorkers()); hwy::Result results[kInputs]; hwy::Params params; params.verbose = false; params.max_evals = kMaxEvals; + + // Only spin for the duration of the benchmark to avoid wasting energy and + // interfering with the other pools. + if (spin) { + pool.SetWaitMode(hwy::PoolWaitMode::kSpin); + } const size_t num_results = Measure(&ForkJoin, reinterpret_cast(&pool), inputs, kInputs, results, params); + if (spin) { + pool.SetWaitMode(hwy::PoolWaitMode::kBlock); + } + for (size_t i = 0; i < num_results; ++i) { printf("%-20s: %5d: %6.2f us; MAD=%4.2f%%\n", caption, static_cast(results[i].input), @@ -303,20 +314,19 @@ TEST(ThreadingTest, BenchJoin) { }; NestedPools pools(0); - measure(pools.AllPackages(), "block packages"); + measure(pools.AllPackages(), false, "block packages"); if (pools.AllClusters(0).NumWorkers() > 1) { - measure(pools.AllClusters(0), "block clusters"); + measure(pools.AllClusters(0), false, "block clusters"); } - measure(pools.Cluster(0, 0), "block in_cluster"); + measure(pools.Cluster(0, 0), false, "block in_cluster"); - Tristate use_spinning = Tristate::kDefault; - pools.MaybeStartSpinning(use_spinning); - if (use_spinning == Tristate::kTrue) { - measure(pools.AllPackages(), "spin packages"); + if (pools.AllPinned()) { + const bool kSpin = true; + measure(pools.AllPackages(), kSpin, "spin packages"); if (pools.AllClusters(0).NumWorkers() > 1) { - measure(pools.AllClusters(0), "spin clusters"); + measure(pools.AllClusters(0), kSpin, "spin clusters"); } - measure(pools.Cluster(0, 0), "spin in_cluster"); + measure(pools.Cluster(0, 0), kSpin, "spin in_cluster"); } }