Allow overriding num threads despite detecting topology

PiperOrigin-RevId: 720188756
This commit is contained in:
Jan Wassenberg 2025-01-27 08:57:08 -08:00 committed by Copybara-Service
parent e997468496
commit a248f76245
2 changed files with 43 additions and 39 deletions

View File

@ -45,35 +45,48 @@ class Pinning {
return false; }
public:
// Returns set of LPs available for use. Subsequent calls return the same
// set as the first, because pinning overwrites the main thread's affinity.
// Returns set of LPs available for use. Cached during the first call
// because subsequent pinning overwrites the main thread's affinity.
// Thread-hostile, not called concurrently.
LPS EnabledLPs() {
if (original_affinity_.Any()) return original_affinity_;
LPS EnabledLPs(const BoundedSlice& lp_slice) {
if (enabled_lps_.Any()) return enabled_lps_;
// Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
LPS enabled_lps;
if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
LPS affinity;
if (HWY_LIKELY(GetThreadAffinity(affinity))) {
// To honor taskset/numactl *and* the users's `lp_slice`, we interpret
// the latter as a slice of the 1-bits of `enabled_lps`. Note that this
// can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = affinity.Count();
size_t enabled_idx = 0;
affinity.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx)) {
enabled_lps_.Set(lp);
}
++enabled_idx;
});
} else {
const size_t num_lps = hwy::TotalLogicalProcessors();
HWY_WARN("unknown OS affinity, considering all %zu LPs enabled.",
num_lps);
HWY_WARN("unknown OS affinity, max %zu LPs and slice %zu.", num_lps,
lp_slice.Num(num_lps));
for (size_t lp = 0; lp < num_lps; ++lp) {
enabled_lps.Set(lp);
if (lp_slice.Contains(num_lps, lp)) {
enabled_lps_.Set(lp);
}
}
}
// Without threading support, only keep the first enabled LP; it might still
// make sense to pin the main thread to avoid migrations.
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
HWY_ASSERT(enabled_lps.Any());
const size_t lp = enabled_lps.First();
enabled_lps = LPS();
enabled_lps.Set(lp);
HWY_ASSERT(enabled_lps_.Any());
const size_t lp = enabled_lps_.First();
enabled_lps_ = LPS();
enabled_lps_.Set(lp);
HWY_WARN("Warning, threads not supported, using only the main thread.");
}
original_affinity_ = enabled_lps;
return enabled_lps;
return enabled_lps_;
}
void SetPolicy(Tristate pin) {
@ -128,7 +141,7 @@ class Pinning {
private:
std::atomic_flag any_error_ = ATOMIC_FLAG_INIT;
bool want_pin_; // set in SetPolicy
LPS original_affinity_;
LPS enabled_lps_;
}; // Pinning
// Singleton saves global affinity across all BoundedTopology instances because
@ -141,7 +154,7 @@ static Pinning& GetPinning() {
BoundedTopology::BoundedTopology(BoundedSlice package_slice,
BoundedSlice cluster_slice,
BoundedSlice lp_slice) {
const LPS enabled_lps = GetPinning().EnabledLPs();
const LPS enabled_lps = GetPinning().EnabledLPs(lp_slice);
#if !GEMMA_DISABLE_TOPOLOGY
if (HWY_LIKELY(!topology_.packages.empty())) {
@ -152,7 +165,7 @@ BoundedTopology::BoundedTopology(BoundedSlice package_slice,
// Topology unknown or no packages with enabled LPs: create a single
// package with one cluster, and one node.
if (HWY_UNLIKELY(NumPackages() == 0)) {
InitFromSlice(enabled_lps, lp_slice);
InitFromLPs(enabled_lps);
}
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
@ -214,9 +227,9 @@ constexpr bool kSplitLargeClusters = false;
constexpr size_t kMaxClusters = 8;
constexpr size_t kMaxLPsPerCluster = 6;
// Topology is unknown, rely on OS affinity and user-specified slice.
BoundedTopology::Package::Package(const LPS& enabled_lps,
BoundedSlice lp_slice) {
// Topology is unknown, use only the given LPs which derive from OS affinity
// and `lp_slice`.
BoundedTopology::Package::Package(const LPS& enabled_lps) {
LPS clusters_lps[kMaxClusters];
const size_t num_clusters =
kSplitLargeClusters
@ -224,16 +237,9 @@ BoundedTopology::Package::Package(const LPS& enabled_lps,
hwy::DivCeil(enabled_lps.Count(), kMaxLPsPerCluster))
: 1;
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index. For example, the first `num_cores` are not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx)) {
clusters_lps[enabled_idx % num_clusters].Set(lp);
}
++enabled_idx;
});
@ -386,9 +392,8 @@ void BoundedTopology::InitFromTopology(const LPS& enabled_lps,
#endif // !GEMMA_DISABLE_TOPOLOGY
void BoundedTopology::InitFromSlice(const LPS& enabled_lps,
BoundedSlice lp_slice) {
packages_.push_back(Package(enabled_lps, lp_slice));
void BoundedTopology::InitFromLPs(const LPS& enabled_lps) {
packages_.push_back(Package(enabled_lps));
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
GetCluster(0, 0).Size());
@ -433,7 +438,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
0, all_packages_->NumWorkers(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread); // each thread has one task
packages_[pkg_idx] =
Package(topology_, pkg_idx, max_workers_per_package, lp_slice);
Package(topology_, pkg_idx, max_workers_per_package);
});
all_pinned_ = GetPinning().AllPinned(&pin_string_);
@ -454,8 +459,7 @@ NestedPools::NestedPools(size_t max_threads, Tristate pin,
}
NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package,
BoundedSlice lp_slice) {
size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx));
const size_t max_workers_per_cluster =

View File

@ -165,7 +165,7 @@ class BoundedTopology {
private:
struct Package {
Package(const LPS& enabled_lps, BoundedSlice lp_slice);
explicit Package(const LPS& enabled_lps);
Package(const LPS& enabled_lps, const hwy::Topology& topology,
size_t pkg_idx, BoundedSlice cluster_slice);
@ -177,7 +177,7 @@ class BoundedTopology {
void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
BoundedSlice cluster_slice);
void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice);
void InitFromLPs(const LPS& enabled_lps);
#if !GEMMA_DISABLE_TOPOLOGY
hwy::Topology topology_;
@ -304,7 +304,7 @@ class NestedPools {
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package, BoundedSlice lp_slice);
size_t max_workers_per_package);
size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const {