Add pin flag to disable pinning. Refs #338

PiperOrigin-RevId: 661389171
Jan Wassenberg authored 2024-08-09 13:44:42 -07:00; committed by Copybara-Service
parent fd1b0743a7
commit 282f73ec2f
4 changed files with 38 additions and 23 deletions

View File

@@ -43,7 +43,7 @@ int main(int argc, char** argv) {
   }

   // Instantiate model and KV Cache
-  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
+  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
   gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);

View File

@@ -166,7 +166,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   // Note that num_threads is an upper bound; we also limit to the number of
   // detected and enabled cores.
-  PerClusterPools pools(app.max_clusters, app.num_threads);
+  PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
   Gemma model = CreateGemma(loader, pools);
   KVCache kv_cache =

View File

@@ -54,10 +54,12 @@ class AppArgs : public ArgsBase<AppArgs> {
  public:
   AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }

+  Path log;  // output
   int verbosity;
   size_t num_threads;  // divided among the detected clusters
   size_t max_clusters;
+  int pin;  // -1 = auto, 0 = no, 1 = yes

   std::string eot_line;

   template <class Visitor>
@@ -67,10 +69,13 @@ class AppArgs : public ArgsBase<AppArgs> {
           "output\n 1 = standard user-facing terminal ui\n 2 = show "
           "developer/debug info).\n Default = 1.",
           2);
     visitor(num_threads, "num_threads", size_t{0},
             "Maximum number of threads to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},
             "Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
+    visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
     visitor(
         eot_line, "eot_line", std::string(""),
         "End of turn line. "

View File

@@ -120,7 +120,9 @@ class PerClusterPools {
   // result in threads not running on their own core, we only allow for
   // *upper bounds* on the number of clusters and threads. The actual number of
   // clusters and threads are still limited by the detected topology.
-  PerClusterPools(size_t max_clusters, size_t max_threads)
+  //
+  // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
+  PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1)
       : have_threading_support_(hwy::HaveThreadingSupport()),
         cores_per_cluster_(DetectCoresPerCluster()),
         outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
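Because the new parameter defaults to -1, existing call sites compile unchanged. A minimal sketch of the three modes (the cluster/thread counts are illustrative, not from this commit):

    gcpp::PerClusterPools auto_pin(/*max_clusters=*/0, /*max_threads=*/0);              // -1: heuristic decides
    gcpp::PerClusterPools no_pin(/*max_clusters=*/0, /*max_threads=*/0, /*pin=*/0);     // never pin
    gcpp::PerClusterPools force_pin(/*max_clusters=*/1, /*max_threads=*/8, /*pin=*/1);  // always pin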
@@ -131,9 +133,11 @@ class PerClusterPools {
     // the first N processors, which are typically on the first socket.
     const size_t num_threads =
         CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
-    fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
+    if (pin == -1) pin = num_threads > 8;
+    fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n",
+            num_threads, pin);
     inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
-    if (num_threads > 1) {
+    if (num_threads > 1 && pin) {
       inner_pools_.back()->Run(0, num_threads,
                                [](uint64_t /*task*/, size_t thread) {
                                  hwy::PinThreadToLogicalProcessor(thread);
@@ -149,25 +153,31 @@ class PerClusterPools {
       inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
     }

-    // For each inner pool, pin their threads AND the associated outer thread
-    // (the one calling inner.Run()) to the enabled cores in the cluster.
-    outer_pool_.Run(
-        0, outer_pool_.NumWorkers(),
-        [this](uint64_t outer, size_t outer_thread) {
-          HWY_ASSERT(outer == outer_thread);  // each outer has one task
-          hwy::ThreadPool& inner = *inner_pools_[outer];
-          const std::vector<size_t> cores =
-              CoresInLPS(cores_per_cluster_[outer]);
-          // May have been capped by max_threads.
-          HWY_ASSERT(inner.NumWorkers() <= cores.size());
-
-          inner.Run(0, inner.NumWorkers(),
-                    [&cores](uint64_t task, size_t thread) {
-                      HWY_ASSERT(task == thread);  // each inner has one task
-                      hwy::PinThreadToLogicalProcessor(cores[task]);
-                    });
-        });
+    if (pin == -1) {
+      pin = (outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers()) >= 12;
+    }
+
+    if (pin) {
+      // For each inner pool, pin their threads AND the associated outer thread
+      // (the one calling inner.Run()) to the enabled cores in the cluster.
+      outer_pool_.Run(
+          0, outer_pool_.NumWorkers(),
+          [this](uint64_t outer, size_t outer_thread) {
+            HWY_ASSERT(outer == outer_thread);  // each outer has one task
+            hwy::ThreadPool& inner = *inner_pools_[outer];
+            const std::vector<size_t> cores =
+                CoresInLPS(cores_per_cluster_[outer]);
+            // May have been capped by max_threads.
+            HWY_ASSERT(inner.NumWorkers() <= cores.size());
+
+            inner.Run(0, inner.NumWorkers(),
+                      [&cores](uint64_t task, size_t thread) {
+                        HWY_ASSERT(task == thread);  // each inner has one task
+                        hwy::PinThreadToLogicalProcessor(cores[task]);
+                      });
+          });
+    }
   }

// Spinning reduces the latency of barrier synchronization, but wastes lots of // Spinning reduces the latency of barrier synchronization, but wastes lots of
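Net effect of the hunks above: pinning now defaults to a heuristic that only pins when there are enough workers for core migration to matter. A hypothetical standalone restatement of that policy (helper name and signature are ours, not part of the commit):

    // Mirrors the policy above: callers force 0/1; otherwise pin only for
    // sufficiently large pools. topology_known is whether
    // DetectCoresPerCluster() succeeded; total_workers is outer * inner
    // workers, or num_threads when the topology is unknown.
    static int DecidePin(int pin, bool topology_known, size_t total_workers) {
      if (pin != -1) return pin;
      return topology_known ? (total_workers >= 12) : (total_workers > 8);
    }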