diff --git a/BUILD.bazel b/BUILD.bazel
index 74f472f..f482e56 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -114,6 +114,7 @@ cc_library(
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:profiler",
+        "@highway//:thread_pool",
     ],
 )
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46242f6..5dc4e11 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300 EXCLUDE_FROM_ALL)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee EXCLUDE_FROM_ALL)
 FetchContent_MakeAvailable(highway)
 
 ## Note: absl needs to be installed by sentencepiece. This will only happen if
diff --git a/MODULE.bazel b/MODULE.bazel
index b6b5f78..e0ba1c7 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -18,7 +18,7 @@ bazel_dep(name = "google_benchmark", version = "1.8.5")
 
 # Require a more recent version.
 git_override(
     module_name = "highway",
-    commit = "1d16731233de45a365b43867f27d0a5f73925300",
+    commit = "9781a1698ee0756ef1eaaf96930113ed7cb6d3ee",
     remote = "https://github.com/google/highway",
 )
diff --git a/README.md b/README.md
index 2963bf6..722c2a8 100644
--- a/README.md
+++ b/README.md
@@ -452,7 +452,7 @@ FetchContent_MakeAvailable(sentencepiece)
 FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
 FetchContent_MakeAvailable(gemma)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 92d327e841d78e11ae888757a3e16d291951cf64)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
 FetchContent_MakeAvailable(highway)
 ```
 
diff --git a/compression/types.h b/compression/types.h
index 661bc42..c3be52a 100644
--- a/compression/types.h
+++ b/compression/types.h
@@ -45,10 +45,11 @@ namespace gcpp {
 // as NEON_WITHOUT_AES. Also skip SVE because SVE2_128 and SVE_256 cover most.
 #define GEMMA_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON | HWY_SVE)
 #elif HWY_ARCH_X86
-// Skip anything older than Haswell (2013); also use Zen4 for recent CPUs,
-// because we do not use anything added by SPR (e.g. FP16) nor AVX 10.2.
+// Skip anything older than Haswell (2013); use Zen4/SPR for recent CPUs.
+// Although we do not use SPR's F16, Zen4 is only enabled for AMD CPUs, hence
+// we also enable SPR. We do not yet use any AVX 10.2 features.
 #define GEMMA_DISABLED_TARGETS \
-  (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX3_SPR | HWY_AVX10_2)
+  (HWY_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | HWY_AVX10_2)
 #endif  // HWY_ARCH_*
 #endif  // GEMMA_DISABLED_TARGETS
 
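Aside on the target mask above: `GEMMA_DISABLED_TARGETS` removes Highway codegen targets at compile time, so dropping `HWY_AVX3_SPR` from the mask means SPR code is now generated and dispatchable on recent Intel CPUs. A minimal sketch for checking which targets actually remain, using Highway's public dispatch API (the `main` harness is illustrative, not part of this change):

```cpp
#include <cstdio>

#include "hwy/targets.h"

int main() {
  // Targets that were both compiled (i.e. not disabled) and are supported
  // by the CPU we are running on.
  for (const int64_t target : hwy::SupportedAndGeneratedTargets()) {
    printf("%s\n", hwy::TargetName(target));
  }
  return 0;
}
```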
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index 49acb50..355f26d 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -84,7 +84,7 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 
 void CallSoftmax(Logits logits, hwy::Profiler& p) {
-  Softmax(logits, p, hwy::Profiler::Thread());
+  Softmax(logits, p, hwy::Profiler::GlobalIdx());
 }
 
 }  // namespace HWY_NAMESPACE
diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 7a63ace..65541d8 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
 FetchContent_MakeAvailable(sentencepiece)
diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt
index da111cc..710f5ee 100644
--- a/examples/simplified_gemma/CMakeLists.txt
+++ b/examples/simplified_gemma/CMakeLists.txt
@@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 1d16731233de45a365b43867f27d0a5f73925300)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9781a1698ee0756ef1eaaf96930113ed7cb6d3ee)
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
diff --git a/gemma/api_server.cc b/gemma/api_server.cc
index ea5377d..f05447b 100644
--- a/gemma/api_server.cc
+++ b/gemma/api_server.cc
@@ -376,8 +376,7 @@ void HandleGenerateContentStreaming(ServerState& state, const httplib::Request&
 
           // Ensure all data is sent
           sink.done();
-
-          return false;  // End streaming
+          return false;  // End streaming
 
         } catch (const std::exception& e) {
           json error_event = {{"error", {{"message", e.what()}}}};
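The `hwy::Profiler::Thread()` → `hwy::Profiler::GlobalIdx()` rename above recurs through the rest of this diff: profiler zones are attributed to the calling thread's global worker index rather than a separately tracked thread id. A sketch of the resulting calling convention, assuming pools built with the `PoolWorkerMapping` introduced in util/threading.cc below (the zone name and task count are illustrative):

```cpp
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

void Example(hwy::Profiler& profiler, hwy::ThreadPool& pool) {
  static const auto zone =
      profiler.AddZone("Example.Zone", hwy::ProfilerFlags::kInclusive);
  {
    // Serial section: the main thread queries its own global index.
    PROFILER_ZONE3(profiler, hwy::Profiler::GlobalIdx(), zone);
  }
  pool.Run(0, 16, [&](uint64_t task, size_t worker) {
    // Inside a task, `worker` is already a global index when the pool was
    // constructed with a PoolWorkerMapping, so pass it through directly.
    PROFILER_ZONE3(profiler, worker, zone);
  });
}
```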
diff --git a/gemma/attention.cc b/gemma/attention.cc
index 576c0b7..a77021a 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -254,7 +254,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
                                   MatMulEnv& env) {
   static const auto zone = env.ctx.profiler.AddZone(
       "Gen.Attention.ComputeQKV", hwy::ProfilerFlags::kInclusive);
-  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
   const hwy::Divisor div_qbatch(qbatch.Size());
   const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
 
@@ -330,7 +330,7 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                                 MatMulEnv& env) {
   static const auto zone = env.ctx.profiler.AddZone(
       "Gen.Attention.SumHeads", hwy::ProfilerFlags::kInclusive);
-  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   (void)layer_config;  // For HWY_DASSERT
   // att_weights and att_out are concatenated heads, each of length
@@ -350,7 +350,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     MatMulEnv& env, int flags) {
   static const auto zone = env.ctx.profiler.AddZone(
       "Gen.Attention", hwy::ProfilerFlags::kInclusive);
-  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 669d7e7..bdf989a 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -155,7 +155,7 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
                             Activations& activations, MatMulEnv& env) {
   static const auto zone =
       env.ctx.profiler.AddZone("Gen.FFW", hwy::ProfilerFlags::kInclusive);
-  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
+  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.ff_biases);  // Only used in Vit.
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 778ecc6..c3e2bac 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -139,7 +139,7 @@ EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt,
              size_t image_token_position = 0) {
   static const auto zone =
       ctx.profiler.AddZone("Gen.Embed", hwy::ProfilerFlags::kInclusive);
-  PROFILER_ZONE3(ctx.profiler, hwy::Profiler::Thread(), zone);
+  PROFILER_ZONE3(ctx.profiler, hwy::Profiler::GlobalIdx(), zone);
 
   // Image tokens just need to be copied.
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
diff --git a/gemma/vit.cc b/gemma/vit.cc
index 1910091..44b1bcb 100644
--- a/gemma/vit.cc
+++ b/gemma/vit.cc
@@ -335,7 +335,8 @@ void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
     // Apply soft embedding norm before input projection.
     CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
       RMSNormInplace(weights_t->PackedScale1(), activations.x.Row(0),
-                     vit_model_dim, env.ctx.profiler, hwy::Profiler::Thread());
+                     vit_model_dim, env.ctx.profiler,
+                     hwy::Profiler::GlobalIdx());
     });
   }
 
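Note the pattern in the call sites above: ops such as `Softmax` and `RMSNormInplace` take the profiler plus the worker index to which their zone is charged. A hedged sketch of a hypothetical op following the same convention (`ScaleInplace` and its zone name are invented for illustration):

```cpp
#include "hwy/base.h"
#include "hwy/profiler.h"

// Hypothetical op following the Softmax/RMSNormInplace convention: the
// caller supplies the profiler and the (global) worker index.
void ScaleInplace(float* HWY_RESTRICT x, size_t n, float mul,
                  hwy::Profiler& profiler, size_t worker) {
  static const auto zone =
      profiler.AddZone("Ops.ScaleInplace", hwy::ProfilerFlags::kInclusive);
  PROFILER_ZONE3(profiler, worker, zone);
  for (size_t i = 0; i < n; ++i) x[i] *= mul;
}
```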
diff --git a/util/threading.cc b/util/threading.cc
index 1001f05..9c4cfe0 100644
--- a/util/threading.cc
+++ b/util/threading.cc
@@ -19,7 +19,6 @@
 #include <stdio.h>
 
 #include <algorithm>  // std::sort
-#include <atomic>
 #include <memory>
 #include <optional>
 #include <vector>
@@ -29,7 +28,6 @@
 
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
-#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -41,85 +39,60 @@ static void SortByDescendingSize(std::vector<T>& groups) {
             [](const T& a, const T& b) { return a.Size() > b.Size(); });
 }
 
-// Singleton, holds the original process affinity and the pinning status.
-class Pinning {
-  static bool InContainer() {
-    return false;
-  }
+static bool InContainer() {
+  return false;  // placeholder for container detection, do not remove
+}
 
- public:
-  void SetPolicy(Tristate pin) {
-    if (pin == Tristate::kDefault) {
-      // Pinning is unreliable inside containers because the hypervisor might
-      // periodically change our affinity mask, or other processes might also
-      // pin themselves to the same LPs.
-      pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
-    }
-    want_pin_ = (pin == Tristate::kTrue);
-    any_error_.clear();
+PinningPolicy::PinningPolicy(Tristate pin) {
+  if (pin == Tristate::kDefault) {
+    // Pinning is unreliable inside containers because the hypervisor might
+    // periodically change our affinity mask, or other processes might also
+    // pin themselves to the same LPs.
+    pin = InContainer() ? Tristate::kFalse : Tristate::kTrue;
   }
+  want_pin_ = (pin == Tristate::kTrue);
+}
 
-  // If want_pin_, tries to pin each worker in `pool` to an LP in `cluster`,
-  // and sets `any_error_` if any fails.
-  void MaybePin(const BoundedTopology& topology, size_t pkg_idx,
-                size_t cluster_idx, const BoundedTopology::Cluster& cluster,
-                hwy::ThreadPool& pool) {
-    const std::vector<size_t> lps = cluster.LPVector();
-    HWY_ASSERT(pool.NumWorkers() <= lps.size());
-    pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
-      HWY_ASSERT(task == thread);  // each worker has one task
+// If `pinning.Want()`, tries to pin each worker in `pool` to an LP in
+// `cluster`, and calls `pinning.NotifyFailed()` if any fails.
+void MaybePin(const BoundedTopology& topology, size_t pkg_idx,
+              size_t cluster_idx, const BoundedTopology::Cluster& cluster,
+              PinningPolicy& pinning, hwy::ThreadPool& pool) {
+  const std::vector<size_t> lps = cluster.LPVector();
+  HWY_ASSERT(pool.NumWorkers() <= lps.size());
+  pool.Run(0, pool.NumWorkers(), [&](uint64_t task, size_t thread) {
+    HWY_ASSERT(task == thread);  // each worker has one task
 
-      char buf[16];  // Linux limitation
-      const int bytes_written = snprintf(
-          buf, sizeof(buf), "P%zu X%02zu C%03d",
-          topology.SkippedPackages() + pkg_idx,
-          topology.SkippedClusters() + cluster_idx, static_cast<int>(task));
-      HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
-      hwy::SetThreadName(buf, 0);  // does not support varargs
+    char buf[16];  // Linux limitation
+    const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d",
+                                       topology.SkippedPackages() + pkg_idx,
+                                       topology.SkippedClusters() + cluster_idx,
+                                       static_cast<int>(task));
+    HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
+    hwy::SetThreadName(buf, 0);  // does not support varargs
 
-      if (HWY_LIKELY(want_pin_)) {
-        if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
-          // Apple does not support pinning, hence do not warn there.
-          if (!HWY_OS_APPLE) {
-            HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n",
-                     static_cast<int>(task), pool.NumWorkers(), lps[task],
-                     lps.size());
-          }
-          (void)any_error_.test_and_set();
+    if (HWY_LIKELY(pinning.Want())) {
+      if (HWY_UNLIKELY(!hwy::PinThreadToLogicalProcessor(lps[task]))) {
+        // Apple does not support pinning, hence do not warn there.
+        if (!HWY_OS_APPLE) {
+          HWY_WARN("Pinning failed for task %d of %zu to %zu (size %zu)\n",
+                   static_cast<int>(task), pool.NumWorkers(), lps[task],
+                   lps.size());
         }
+        pinning.NotifyFailed();
       }
-    });
-  }
-
-  // Called ONCE after all MaybePin because it invalidates the error status.
-  bool AllPinned(const char** pin_string) {
-    // If !want_pin_, MaybePin will return without setting any_error_, but in
-    // that case we still want to return false to avoid spinning.
-    // .test() was only added in C++20, so we use .test_and_set() instead.
-    const bool all_pinned = want_pin_ && !any_error_.test_and_set();
-    *pin_string = all_pinned  ? "pinned"
-                  : want_pin_ ? "pinning failed"
-                              : "pinning skipped";
-    return all_pinned;
-  }
-
- private:
-  std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
-  bool want_pin_;  // set in SetPolicy
-};  // Pinning
-
-// Singleton saves global affinity across all BoundedTopology instances because
-// pinning overwrites it.
-static Pinning& GetPinning() {
-  static Pinning pinning;
-  return pinning;
+    }
+  });
 }
 
 static PoolPtr MakePool(const Allocator& allocator, size_t num_workers,
+                        hwy::PoolWorkerMapping mapping,
                         std::optional<size_t> node = std::nullopt) {
   // `ThreadPool` expects the number of threads to create, which is one less
   // than the number of workers, but avoid underflow if zero.
   const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
-  PoolPtr ptr = allocator.AllocClasses<hwy::ThreadPool>(1, num_threads);
+  PoolPtr ptr =
+      allocator.AllocClasses<hwy::ThreadPool>(1, num_threads, mapping);
   const size_t bytes =
       hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes());
   if (node.has_value() && allocator.ShouldBind()) {
@@ -142,10 +115,11 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
 
 NestedPools::NestedPools(const BoundedTopology& topology,
                          const Allocator& allocator, size_t max_threads,
-                         Tristate pin) {
-  GetPinning().SetPolicy(pin);
+                         Tristate pin)
+    : pinning_(pin) {
   packages_.resize(topology.NumPackages());
-  all_packages_ = MakePool(allocator, packages_.size());
+  all_packages_ =
+      MakePool(allocator, packages_.size(), hwy::PoolWorkerMapping());
   const size_t max_workers_per_package =
       DivideMaxAcross(max_threads, packages_.size());
   // Each worker in all_packages_, including the main thread, will be the
@@ -153,11 +127,11 @@ NestedPools::NestedPools(const BoundedTopology& topology,
   // `cluster.lps` if `pin`.
   all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
     HWY_ASSERT(pkg_idx == thread);  // each thread has one task
-    packages_[pkg_idx] =
-        Package(topology, allocator, pkg_idx, max_workers_per_package);
+    packages_[pkg_idx] = Package(topology, allocator, pinning_, pkg_idx,
+                                 max_workers_per_package);
   });
 
-  all_pinned_ = GetPinning().AllPinned(&pin_string_);
+  all_pinned_ = pinning_.AllPinned(&pin_string_);
 
   // For mapping package/cluster/thread to noncontiguous TLS indices, in case
   // cluster/thread counts differ.
@@ -172,8 +146,6 @@ NestedPools::NestedPools(const BoundedTopology& topology,
   HWY_ASSERT(max_clusters_per_package_ <= 64);
   HWY_ASSERT(max_workers_per_cluster_ >= 1);
   HWY_ASSERT(max_workers_per_cluster_ <= 256);
-
-  hwy::Profiler::Get().SetMaxThreads(MaxWorkers());
 }
 
 // `max_or_zero` == 0 means no limit.
@@ -182,15 +154,22 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
 }
 
 NestedPools::Package::Package(const BoundedTopology& topology,
-                              const Allocator& allocator, size_t pkg_idx,
+                              const Allocator& allocator,
+                              PinningPolicy& pinning, size_t pkg_idx,
                               size_t max_workers_per_package) {
   // Pre-allocate because elements are set concurrently.
   clusters_.resize(topology.NumClusters(pkg_idx));
   const size_t max_workers_per_cluster =
       DivideMaxAcross(max_workers_per_package, clusters_.size());
-  all_clusters_ = MakePool(allocator, clusters_.size(),
-                           topology.GetCluster(pkg_idx, 0).Node());
+  const BoundedTopology::Cluster& cluster0 = topology.GetCluster(pkg_idx, 0);
+  // Core 0 of each cluster. The second argument is the cluster size, not the
+  // number of clusters. We ensure that it is the same for all clusters so
+  // that the `GlobalIdx` computation is consistent within and across clusters.
+  const hwy::PoolWorkerMapping all_clusters_mapping(hwy::kAllClusters,
+                                                    cluster0.Size());
+  all_clusters_ = MakePool(allocator, clusters_.size(), all_clusters_mapping,
+                           cluster0.Node());
   // Parallel so we also pin the calling worker in `all_clusters` to
   // `cluster.lps`.
   all_clusters_->Run(
@@ -198,12 +177,14 @@ NestedPools::Package::Package(const BoundedTopology& topology,
       0, clusters_.size(), [&](uint64_t cluster_idx, size_t thread) {
         HWY_ASSERT(cluster_idx == thread);  // each thread has one task
         const BoundedTopology::Cluster& cluster =
             topology.GetCluster(pkg_idx, cluster_idx);
+        HWY_ASSERT(cluster.Size() == cluster0.Size());
         clusters_[cluster_idx] = MakePool(
             allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster),
+            hwy::PoolWorkerMapping(cluster_idx, cluster.Size()),
             cluster.Node());
         // Pin workers AND the calling thread from `all_clusters`.
-        GetPinning().MaybePin(topology, pkg_idx, cluster_idx, cluster,
-                              *clusters_[cluster_idx]);
+        MaybePin(topology, pkg_idx, cluster_idx, cluster, pinning,
+                 *clusters_[cluster_idx]);
       });
 }
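The assertion that all clusters have equal size is what makes a linear global worker index well-defined. A sketch of my reading of the mapping (this is an assumption about `PoolWorkerMapping`'s layout, not Highway's documented formula): each (cluster, local worker) pair maps to a unique, stable slot, so profiler zones attributed via `GlobalIdx` never collide across cluster pools.

```cpp
#include <stddef.h>

// Assumed layout only: with equal cluster sizes, a linear mapping assigns
// each worker a unique global index. Unequal sizes would make the stride
// ambiguous, hence the HWY_ASSERT(cluster.Size() == cluster0.Size()) above.
size_t GlobalWorkerIdx(size_t cluster_idx, size_t cluster_size,
                       size_t local_idx) {
  return cluster_idx * cluster_size + local_idx;
}
```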
"pinning failed" - : "pinning skipped"; - return all_pinned; - } - - private: - std::atomic_flag any_error_ = ATOMIC_FLAG_INIT; - bool want_pin_; // set in SetPolicy -}; // Pinning - -// Singleton saves global affinity across all BoundedTopology instances because -// pinning overwrites it. -static Pinning& GetPinning() { - static Pinning pinning; - return pinning; + } + }); } static PoolPtr MakePool(const Allocator& allocator, size_t num_workers, + hwy::PoolWorkerMapping mapping, std::optional node = std::nullopt) { // `ThreadPool` expects the number of threads to create, which is one less // than the number of workers, but avoid underflow if zero. const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1; - PoolPtr ptr = allocator.AllocClasses(1, num_threads); + PoolPtr ptr = + allocator.AllocClasses(1, num_threads, mapping); const size_t bytes = hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes()); if (node.has_value() && allocator.ShouldBind()) { @@ -142,10 +115,11 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) { NestedPools::NestedPools(const BoundedTopology& topology, const Allocator& allocator, size_t max_threads, - Tristate pin) { - GetPinning().SetPolicy(pin); + Tristate pin) + : pinning_(pin) { packages_.resize(topology.NumPackages()); - all_packages_ = MakePool(allocator, packages_.size()); + all_packages_ = + MakePool(allocator, packages_.size(), hwy::PoolWorkerMapping()); const size_t max_workers_per_package = DivideMaxAcross(max_threads, packages_.size()); // Each worker in all_packages_, including the main thread, will be the @@ -153,11 +127,11 @@ NestedPools::NestedPools(const BoundedTopology& topology, // `cluster.lps` if `pin`. all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) { HWY_ASSERT(pkg_idx == thread); // each thread has one task - packages_[pkg_idx] = - Package(topology, allocator, pkg_idx, max_workers_per_package); + packages_[pkg_idx] = Package(topology, allocator, pinning_, pkg_idx, + max_workers_per_package); }); - all_pinned_ = GetPinning().AllPinned(&pin_string_); + all_pinned_ = pinning_.AllPinned(&pin_string_); // For mapping package/cluster/thread to noncontiguous TLS indices, in case // cluster/thread counts differ. @@ -172,8 +146,6 @@ NestedPools::NestedPools(const BoundedTopology& topology, HWY_ASSERT(max_clusters_per_package_ <= 64); HWY_ASSERT(max_workers_per_cluster_ >= 1); HWY_ASSERT(max_workers_per_cluster_ <= 256); - - hwy::Profiler::Get().SetMaxThreads(MaxWorkers()); } // `max_or_zero` == 0 means no limit. @@ -182,15 +154,22 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { } NestedPools::Package::Package(const BoundedTopology& topology, - const Allocator& allocator, size_t pkg_idx, + const Allocator& allocator, + PinningPolicy& pinning, size_t pkg_idx, size_t max_workers_per_package) { // Pre-allocate because elements are set concurrently. clusters_.resize(topology.NumClusters(pkg_idx)); const size_t max_workers_per_cluster = DivideMaxAcross(max_workers_per_package, clusters_.size()); - all_clusters_ = MakePool(allocator, clusters_.size(), - topology.GetCluster(pkg_idx, 0).Node()); + const BoundedTopology::Cluster& cluster0 = topology.GetCluster(pkg_idx, 0); + // Core 0 of each cluster. The second argument is the cluster size, not + // number of clusters. We ensure that it is the same for all clusters so that + // the `GlobalIdx` computation is consistent within and across clusters. 
diff --git a/util/threading_context.cc b/util/threading_context.cc
index 90a64d1..8ffd4db 100644
--- a/util/threading_context.cc
+++ b/util/threading_context.cc
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "hwy/aligned_allocator.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
 #include "hwy/tests/test_util.h"  // RandomState
 
@@ -28,7 +29,11 @@ namespace gcpp {
 
 // Invokes `pool.Run` with varying task counts until auto-tuning completes, or
 // an upper bound just in case.
-static void TunePool(hwy::ThreadPool& pool) {
+static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
+  pool.SetWaitMode(wait_mode);
+
+// TODO(janwas): re-enable after investigating potential deadlock.
+#if 0
   const size_t num_workers = pool.NumWorkers();
   // pool.Run would just be a serial loop without auto-tuning, so skip.
   if (num_workers == 1) return;
@@ -69,6 +74,22 @@ static void TunePool(hwy::ThreadPool& pool) {
     HWY_ASSERT(total == prev_total + expected);
     prev_total += expected;
   }
+#endif
+}
+
+static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) {
+  TunePool(wait_mode, pools.AllPackages());
+  for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) {
+    hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
+    TunePool(wait_mode, clusters);
+
+    // Run in parallel because Turin CPUs have 16 clusters, and in real
+    // usage, we often run all at the same time.
+    clusters.Run(0, clusters.NumWorkers(),
+                 [&](uint64_t cluster_idx, size_t /*thread*/) {
+                   TunePool(wait_mode, pools.Cluster(pkg_idx, cluster_idx));
+                 });
+  }
 }
 
 ThreadingContext::ThreadingContext(const ThreadingArgs& args)
@@ -80,18 +101,9 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
       allocator(topology, cache_info, args.bind != Tristate::kFalse),
       pools(topology, allocator, args.max_threads, args.pin) {
   PROFILER_ZONE("Startup.ThreadingContext autotune");
-  TunePool(pools.AllPackages());
-  for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) {
-    hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
-    TunePool(clusters);
-
-    // Run in parallel because Turin CPUs have 16, and in real usage, we often
-    // run all at the same time.
-    clusters.Run(0, clusters.NumWorkers(),
-                 [&](uint64_t cluster_idx, size_t /*thread*/) {
-                   TunePool(pools.Cluster(pkg_idx, cluster_idx));
-                 });
-  }
+  TunePools(hwy::PoolWaitMode::kSpin, pools);
+  // kBlock is the default, hence set/tune it last.
+  TunePools(hwy::PoolWaitMode::kBlock, pools);
 }
 
 }  // namespace gcpp
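On the tuning order: `TunePool` now also sets the wait mode, so running the spin pass first and the block pass last leaves every pool in the default blocking mode. A sketch of how a caller could opt back into spinning afterwards (the `AllPinned()` accessor name is my assumption; only the `all_pinned_` member is visible in this diff):

```cpp
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading_context.h"

void MaybeEnableSpinning(gcpp::ThreadingContext& ctx) {
  // Spinning is only safe/profitable when workers stay on their LPs;
  // otherwise unpinned spinners can migrate and steal cycles.
  if (ctx.pools.AllPinned()) {  // assumed accessor for all_pinned_
    ctx.pools.SetWaitMode(hwy::PoolWaitMode::kSpin);
  }
}
```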
diff --git a/util/threading_context.h b/util/threading_context.h
index ac42526..ff4ff62 100644
--- a/util/threading_context.h
+++ b/util/threading_context.h
@@ -41,7 +41,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 
   // For BoundedTopology:
   size_t skip_packages;
-  size_t max_packages = 1;
+  size_t max_packages = 1;  // some users assign 1 to this, hence non-const.
   size_t skip_clusters;
   size_t max_clusters;
   size_t skip_lps;
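Finally, a sketch of the "nested parallel-for" shape that the `NestedPools` comment and `HierarchicalParallelFor` refer to. `AllClusters` and `Cluster` appear in this diff; the static partitioning of tasks across clusters is illustrative only:

```cpp
#include <stddef.h>
#include <stdint.h>

#include "util/threading.h"

// Outer loop over clusters, inner loop over each cluster's workers. Each
// cluster processes a contiguous share of [0, num_tasks).
void HierarchicalForEach(gcpp::NestedPools& pools, size_t pkg_idx,
                         size_t num_tasks) {
  hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
  const size_t num_clusters = clusters.NumWorkers();
  clusters.Run(0, num_clusters, [&](uint64_t cluster_idx, size_t /*worker*/) {
    const size_t begin = num_tasks * cluster_idx / num_clusters;
    const size_t end = num_tasks * (cluster_idx + 1) / num_clusters;
    pools.Cluster(pkg_idx, cluster_idx)
        .Run(begin, end, [&](uint64_t task, size_t worker) {
          // ... process `task`, attributing profiler zones to `worker` ...
        });
  });
}
```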