Allow overriding num threads despite detecting topology

PiperOrigin-RevId: 720188756
Author: Jan Wassenberg (2025-01-27 08:57:08 -08:00), committed by Copybara-Service
Parent: e997468496
Commit: a248f76245
2 changed files with 43 additions and 39 deletions

File 1 of 2:

@@ -45,35 +45,48 @@ class Pinning {
     return false;
   }

  public:
-  // Returns set of LPs available for use. Subsequent calls return the same
-  // set as the first, because pinning overwrites the main thread's affinity.
+  // Returns set of LPs available for use. Cached during the first call
+  // because subsequent pinning overwrites the main thread's affinity.
   // Thread-hostile, not called concurrently.
-  LPS EnabledLPs() {
-    if (original_affinity_.Any()) return original_affinity_;
-    // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
-    LPS enabled_lps;
-    if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
+  LPS EnabledLPs(const BoundedSlice& lp_slice) {
+    if (enabled_lps_.Any()) return enabled_lps_;
+    LPS affinity;
+    if (HWY_LIKELY(GetThreadAffinity(affinity))) {
+      // To honor taskset/numactl *and* the user's `lp_slice`, we interpret
+      // the latter as a slice of the 1-bits of `enabled_lps`. Note that this
+      // can be used to exclude hyperthreads because Linux groups LPs by
+      // sibling index. For example, the first `num_cores` are not siblings.
+      const size_t detected = affinity.Count();
+      size_t enabled_idx = 0;
+      affinity.Foreach([&](size_t lp) {
+        if (lp_slice.Contains(detected, enabled_idx)) {
+          enabled_lps_.Set(lp);
+        }
+        ++enabled_idx;
+      });
+    } else {
       const size_t num_lps = hwy::TotalLogicalProcessors();
-      HWY_WARN("unknown OS affinity, considering all %zu LPs enabled.",
-               num_lps);
+      HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps,
+               lp_slice.Num(num_lps));
       for (size_t lp = 0; lp < num_lps; ++lp) {
-        enabled_lps.Set(lp);
+        if (lp_slice.Contains(num_lps, lp)) {
+          enabled_lps_.Set(lp);
+        }
       }
     }
     // Without threading support, only keep the first enabled LP; it might still
     // make sense to pin the main thread to avoid migrations.
     if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
-      HWY_ASSERT(enabled_lps.Any());
-      const size_t lp = enabled_lps.First();
-      enabled_lps = LPS();
-      enabled_lps.Set(lp);
+      HWY_ASSERT(enabled_lps_.Any());
+      const size_t lp = enabled_lps_.First();
+      enabled_lps_ = LPS();
+      enabled_lps_.Set(lp);
       HWY_WARN("Warning, threads not supported, using only the main thread.");
     }
-    original_affinity_ = enabled_lps;
-    return enabled_lps;
+    return enabled_lps_;
   }

   void SetPolicy(Tristate pin) {
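
Note on the "slice of the 1-bits" rule above: `lp_slice` indexes the enabled
LPs (the 1-bits of the affinity mask), not raw LP numbers, so it composes
with taskset/numactl. A minimal standalone sketch; `SliceSketch` and its
skip/max fields are hypothetical stand-ins inferred from how `BoundedSlice`
is used with `Contains()` and `Num()` in this diff:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for BoundedSlice: a [skip, skip+max) window over
// `detected` items, where max == 0 means "no limit" (assumed semantics).
struct SliceSketch {
  size_t skip = 0;
  size_t max = 0;
  bool Contains(size_t detected, size_t idx) const {
    const size_t end = (max == 0) ? detected : std::min(detected, skip + max);
    return skip <= idx && idx < end;
  }
};

int main() {
  // Suppose taskset left LPs {0, 1, 4, 5} enabled (the 1-bits of the mask).
  const std::vector<size_t> enabled = {0, 1, 4, 5};
  // skip=1, max=2 selects the 2nd and 3rd *enabled* LPs, i.e. LPs 1 and 4,
  // regardless of the gaps the OS affinity mask leaves in the LP numbering.
  const SliceSketch lp_slice{1, 2};
  size_t enabled_idx = 0;
  for (const size_t lp : enabled) {
    if (lp_slice.Contains(enabled.size(), enabled_idx)) {
      std::printf("use LP %zu\n", lp);
    }
    ++enabled_idx;
  }
  return 0;
}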
@@ -128,7 +141,7 @@ class Pinning {
  private:
   std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
   bool want_pin_;  // set in SetPolicy
-  LPS original_affinity_;
+  LPS enabled_lps_;
 };  // Pinning

 // Singleton saves global affinity across all BoundedTopology instances because
@@ -141,7 +154,7 @@ static Pinning& GetPinning() {
 BoundedTopology::BoundedTopology(BoundedSlice package_slice,
                                  BoundedSlice cluster_slice,
                                  BoundedSlice lp_slice) {
-  const LPS enabled_lps = GetPinning().EnabledLPs();
+  const LPS enabled_lps = GetPinning().EnabledLPs(lp_slice);

 #if !GEMMA_DISABLE_TOPOLOGY
   if (HWY_LIKELY(!topology_.packages.empty())) {
@@ -152,7 +165,7 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice,
   // Topology unknown or no packages with enabled LPs: create a single
   // package with one cluster, and one node.
   if (HWY_UNLIKELY(NumPackages() == 0)) {
-    InitFromSlice(enabled_lps, lp_slice);
+    InitFromLPs(enabled_lps);
   }
   HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
@@ -214,9 +227,9 @@ constexpr bool kSplitLargeClusters = false;
 constexpr size_t kMaxClusters = 8;
 constexpr size_t kMaxLPsPerCluster = 6;

-// Topology is unknown, rely on OS affinity and user-specified slice.
-BoundedTopology::Package::Package(const LPS& enabled_lps,
-                                  BoundedSlice lp_slice) {
+// Topology is unknown, use only the given LPs which derive from OS affinity
+// and `lp_slice`.
+BoundedTopology::Package::Package(const LPS& enabled_lps) {
   LPS clusters_lps[kMaxClusters];
   const size_t num_clusters =
       kSplitLargeClusters
@@ -224,16 +237,9 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
           ? HWY_MIN(kMaxClusters,
                     hwy::DivCeil(enabled_lps.Count(), kMaxLPsPerCluster))
           : 1;

-  // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
-  // we honor both the OS affinity and the user-specified slice. Note that
-  // this can be used to exclude hyperthreads because Linux groups LPs by
-  // sibling index. For example, the first `num_cores` are not siblings.
-  const size_t detected = enabled_lps.Count();
   size_t enabled_idx = 0;
   enabled_lps.Foreach([&](size_t lp) {
-    if (lp_slice.Contains(detected, enabled_idx)) {
       clusters_lps[enabled_idx % num_clusters].Set(lp);
-    }
     ++enabled_idx;
   });
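
For illustration, the round-robin split above in isolation: a sketch using
std::vector in place of the LPS bitset, with a function name of our choosing.

#include <cstddef>
#include <vector>

// Deal enabled LPs into num_clusters buckets by enabled index modulo
// num_clusters, mirroring `clusters_lps[enabled_idx % num_clusters].Set(lp)`.
std::vector<std::vector<size_t>> SplitRoundRobin(
    const std::vector<size_t>& enabled_lps, size_t num_clusters) {
  std::vector<std::vector<size_t>> clusters(num_clusters);
  size_t enabled_idx = 0;
  for (const size_t lp : enabled_lps) {
    clusters[enabled_idx % num_clusters].push_back(lp);
    ++enabled_idx;
  }
  return clusters;
}

With 8 enabled LPs and num_clusters = 2 this yields {0,2,4,6} and {1,3,5,7}.
Since kSplitLargeClusters is false above, num_clusters is 1 in practice and
all enabled LPs land in a single synthetic cluster.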
@@ -386,9 +392,8 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
 #endif  // !GEMMA_DISABLE_TOPOLOGY

-void BoundedTopology::InitFromSlice(const LPS& enabled_lps,
-                                    BoundedSlice lp_slice) {
-  packages_.push_back(Package(enabled_lps, lp_slice));
+void BoundedTopology::InitFromLPs(const LPS& enabled_lps) {
+  packages_.push_back(Package(enabled_lps));
   snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
            GetCluster(0, 0).Size());
@@ -433,7 +438,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
       0, all_packages_->NumWorkers(), [&](uint64_t pkg_idx, size_t thread) {
         HWY_ASSERT(pkg_idx == thread);  // each thread has one task
         packages_[pkg_idx] =
-            Package(topology_, pkg_idx, max_workers_per_package, lp_slice);
+            Package(topology_, pkg_idx, max_workers_per_package);
       });

   all_pinned_ = GetPinning().AllPinned(&pin_string_);
@@ -454,8 +459,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
 }

 NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
-                              size_t max_workers_per_package,
-                              BoundedSlice lp_slice) {
+                              size_t max_workers_per_package) {
   // Pre-allocate because elements are set concurrently.
   clusters_.resize(topology.NumClusters(pkg_idx));
   const size_t max_workers_per_cluster =

File 2 of 2:

@@ -165,7 +165,7 @@ class BoundedTopology {
  private:
   struct Package {
-    Package(const LPS& enabled_lps, BoundedSlice lp_slice);
+    explicit Package(const LPS& enabled_lps);
     Package(const LPS& enabled_lps, const hwy::Topology& topology,
             size_t pkg_idx, BoundedSlice cluster_slice);
@@ -177,7 +177,7 @@ class BoundedTopology {
   void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
                         BoundedSlice cluster_slice);
-  void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice);
+  void InitFromLPs(const LPS& enabled_lps);

 #if !GEMMA_DISABLE_TOPOLOGY
   hwy::Topology topology_;
@@ -304,7 +304,7 @@ class NestedPools {
    public:
     Package() = default;  // for vector
     Package(const BoundedTopology& topology, size_t pkg_idx,
-            size_t max_workers_per_package, BoundedSlice lp_slice);
+            size_t max_workers_per_package);

     size_t NumClusters() const { return clusters_.size(); }
     size_t MaxWorkersPerCluster() const {
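
Caller-visible effect of the change: `lp_slice` now limits the enabled LPs
(and hence thread counts) even when topology detection succeeds, rather than
only in the unknown-topology fallback. A hedged usage sketch; the include
path and the `BoundedSlice(skip, max)` constructor are assumptions:

#include "util/threading.h"  // assumed include path

void UseFirstFourEnabledLPs() {
  const gcpp::BoundedSlice package_slice;  // default: all packages
  const gcpp::BoundedSlice cluster_slice;  // default: all clusters
  // Assumed ctor: skip 0 enabled LPs, use at most 4 of them.
  const gcpp::BoundedSlice lp_slice(/*skip=*/0, /*max=*/4);
  // EnabledLPs(lp_slice) now applies the slice up front, so at most four
  // LPs are used no matter what the detected topology reports.
  gcpp::BoundedTopology topology(package_slice, cluster_slice, lp_slice);
}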