mirror of https://github.com/google/gemma.cpp.git
Add pin flag to disable pinning. Refs #338
PiperOrigin-RevId: 661389171
This commit is contained in:
parent
fd1b0743a7
commit
282f73ec2f
|
|
@ -43,7 +43,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
|
||||
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
||||
gcpp::KVCache kv_cache =
|
||||
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
// Note that num_threads is an upper bound; we also limit to the number of
|
||||
// detected and enabled cores.
|
||||
PerClusterPools pools(app.max_clusters, app.num_threads);
|
||||
PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||
|
||||
Gemma model = CreateGemma(loader, pools);
|
||||
KVCache kv_cache =
|
||||
|
|
|
|||
|
|
@ -54,10 +54,12 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
public:
|
||||
AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
|
||||
Path log; // output
|
||||
int verbosity;
|
||||
|
||||
size_t num_threads; // divided among the detected clusters
|
||||
size_t max_clusters;
|
||||
int pin; // -1 = auto, 0 = no, 1 = yes
|
||||
|
||||
std::string eot_line;
|
||||
|
||||
template <class Visitor>
|
||||
|
|
@ -67,10 +69,13 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
|
||||
visitor(num_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
|
|
|
|||
|
|
@ -120,7 +120,9 @@ class PerClusterPools {
|
|||
// result in threads not running on their own core, we only allow for
|
||||
// *upper bounds* on the number of clusters and threads. The actual number of
|
||||
// clusters and threads are still limited by the detected topology.
|
||||
PerClusterPools(size_t max_clusters, size_t max_threads)
|
||||
//
|
||||
// `pin` is 0 or 1 to force disable/enable, or -1 to choose automatically.
|
||||
PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1)
|
||||
: have_threading_support_(hwy::HaveThreadingSupport()),
|
||||
cores_per_cluster_(DetectCoresPerCluster()),
|
||||
outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
|
||||
|
|
@ -131,9 +133,11 @@ class PerClusterPools {
|
|||
// the first N processors, which are typically on the first socket.
|
||||
const size_t num_threads =
|
||||
CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
|
||||
fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
|
||||
if (pin == -1) pin = num_threads > 8;
|
||||
fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n",
|
||||
num_threads, pin);
|
||||
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
|
||||
if (num_threads > 1) {
|
||||
if (num_threads > 1 && pin) {
|
||||
inner_pools_.back()->Run(0, num_threads,
|
||||
[](uint64_t /*task*/, size_t thread) {
|
||||
hwy::PinThreadToLogicalProcessor(thread);
|
||||
|
|
@ -149,25 +153,31 @@ class PerClusterPools {
|
|||
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
|
||||
}
|
||||
|
||||
// For each inner pool, pin their threads AND the associated outer thread
|
||||
// (the one calling inner.Run()) to the enabled cores in the cluster.
|
||||
outer_pool_.Run(
|
||||
0, outer_pool_.NumWorkers(),
|
||||
[this](uint64_t outer, size_t outer_thread) {
|
||||
HWY_ASSERT(outer == outer_thread); // each outer has one task
|
||||
hwy::ThreadPool& inner = *inner_pools_[outer];
|
||||
if (pin == -1) {
|
||||
pin = (outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers()) >= 12;
|
||||
}
|
||||
|
||||
const std::vector<size_t> cores =
|
||||
CoresInLPS(cores_per_cluster_[outer]);
|
||||
// May have been capped by max_threads.
|
||||
HWY_ASSERT(inner.NumWorkers() <= cores.size());
|
||||
if (pin) {
|
||||
// For each inner pool, pin their threads AND the associated outer thread
|
||||
// (the one calling inner.Run()) to the enabled cores in the cluster.
|
||||
outer_pool_.Run(
|
||||
0, outer_pool_.NumWorkers(),
|
||||
[this](uint64_t outer, size_t outer_thread) {
|
||||
HWY_ASSERT(outer == outer_thread); // each outer has one task
|
||||
hwy::ThreadPool& inner = *inner_pools_[outer];
|
||||
|
||||
inner.Run(0, inner.NumWorkers(),
|
||||
[&cores](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each inner has one task
|
||||
hwy::PinThreadToLogicalProcessor(cores[task]);
|
||||
});
|
||||
});
|
||||
const std::vector<size_t> cores =
|
||||
CoresInLPS(cores_per_cluster_[outer]);
|
||||
// May have been capped by max_threads.
|
||||
HWY_ASSERT(inner.NumWorkers() <= cores.size());
|
||||
|
||||
inner.Run(0, inner.NumWorkers(),
|
||||
[&cores](uint64_t task, size_t thread) {
|
||||
HWY_ASSERT(task == thread); // each inner has one task
|
||||
hwy::PinThreadToLogicalProcessor(cores[task]);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Spinning reduces the latency of barrier synchronization, but wastes lots of
|
||||
|
|
|
|||
Loading…
Reference in New Issue