mirror of https://github.com/google/gemma.cpp.git
Add pin flag to disable pinning. Refs #338
PiperOrigin-RevId: 661389171
This commit is contained in:
parent fd1b0743a7
commit 282f73ec2f
@@ -43,7 +43,7 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
+  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
   gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
@@ -166,7 +166,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
 
   // Note that num_threads is an upper bound; we also limit to the number of
   // detected and enabled cores.
-  PerClusterPools pools(app.max_clusters, app.num_threads);
+  PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
 
   Gemma model = CreateGemma(loader, pools);
   KVCache kv_cache =
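Both call sites above forward the parsed flag unchanged. A minimal sketch of the resulting flow, assuming the usual setup code surrounding these hunks (the argument objects are as in the diff; the snippet itself is illustrative, not part of the commit):

  // AppArgs parses --pin from argv; PerClusterPools receives it verbatim.
  gcpp::AppArgs app(argc, argv);  // app.pin is -1 (auto), 0 (no) or 1 (yes)
  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
  gcpp::Gemma model = gcpp::CreateGemma(loader, pools);  // loader set up as usual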
@@ -54,10 +54,12 @@ class AppArgs : public ArgsBase<AppArgs> {
  public:
   AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
   Path log;  // output
   int verbosity;
 
   size_t num_threads;  // divided among the detected clusters
   size_t max_clusters;
+  int pin;  // -1 = auto, 0 = no, 1 = yes
 
   std::string eot_line;
 
   template <class Visitor>
@@ -67,10 +69,13 @@ class AppArgs : public ArgsBase<AppArgs> {
             "output\n 1 = standard user-facing terminal ui\n 2 = show "
             "developer/debug info).\n Default = 1.",
             2);
 
     visitor(num_threads, "num_threads", size_t{0},
             "Maximum number of threads to use; default 0 = unlimited.", 2);
     visitor(max_clusters, "max_clusters", size_t{0},
             "Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
+    visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
 
     visitor(
         eot_line, "eot_line", std::string(""),
         "End of turn line. "
@@ -120,7 +120,9 @@ class PerClusterPools {
   // result in threads not running on their own core, we only allow for
   // *upper bounds* on the number of clusters and threads. The actual number of
   // clusters and threads are still limited by the detected topology.
-  PerClusterPools(size_t max_clusters, size_t max_threads)
+  //
+  // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
+  PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1)
       : have_threading_support_(hwy::HaveThreadingSupport()),
         cores_per_cluster_(DetectCoresPerCluster()),
         outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
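Because the new parameter is a trailing default (pin = -1), existing callers compile unchanged and keep the automatic behavior. A hedged usage sketch (standalone construction shown for illustration only):

  // Force-disable pinning regardless of detected topology:
  gcpp::PerClusterPools no_pin(/*max_clusters=*/0, /*max_threads=*/0, /*pin=*/0);
  // Omitting the argument leaves pin = -1, i.e. the automatic heuristic:
  gcpp::PerClusterPools auto_pin(/*max_clusters=*/0, /*max_threads=*/0);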
@@ -131,9 +133,11 @@
     // the first N processors, which are typically on the first socket.
     const size_t num_threads =
         CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
-    fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
+    if (pin == -1) pin = num_threads > 8;
+    fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n",
+            num_threads, pin);
     inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
-    if (num_threads > 1) {
+    if (num_threads > 1 && pin) {
       inner_pools_.back()->Run(0, num_threads,
                                [](uint64_t /*task*/, size_t thread) {
                                  hwy::PinThreadToLogicalProcessor(thread);
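On this unknown-topology fallback path, auto mode (pin == -1) resolves to pinning only when more than 8 threads are in use. For example, a machine reporting 16 logical processors with max_threads = 0 gets num_threads = 16 / 2 = 8, so pin becomes 0 and the pinning loop is skipped; with 32 logical processors, num_threads = 16 and pin becomes 1. A standalone restatement (hypothetical helper, not in the commit):

  // Resolve the pin flag when the CPU topology could not be detected.
  int ResolvePinUnknownTopology(int pin, size_t num_threads) {
    if (pin == -1) pin = num_threads > 8;  // auto: pin only beyond 8 threads
    return pin;
  }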
@@ -149,25 +153,31 @@
       inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
     }
 
-    // For each inner pool, pin their threads AND the associated outer thread
-    // (the one calling inner.Run()) to the enabled cores in the cluster.
-    outer_pool_.Run(
-        0, outer_pool_.NumWorkers(),
-        [this](uint64_t outer, size_t outer_thread) {
-          HWY_ASSERT(outer == outer_thread);  // each outer has one task
-          hwy::ThreadPool& inner = *inner_pools_[outer];
-
-          const std::vector<size_t> cores =
-              CoresInLPS(cores_per_cluster_[outer]);
-          // May have been capped by max_threads.
-          HWY_ASSERT(inner.NumWorkers() <= cores.size());
-
-          inner.Run(0, inner.NumWorkers(),
-                    [&cores](uint64_t task, size_t thread) {
-                      HWY_ASSERT(task == thread);  // each inner has one task
-                      hwy::PinThreadToLogicalProcessor(cores[task]);
-                    });
-        });
+    if (pin == -1) {
+      pin = (outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers()) >= 12;
+    }
+
+    if (pin) {
+      // For each inner pool, pin their threads AND the associated outer thread
+      // (the one calling inner.Run()) to the enabled cores in the cluster.
+      outer_pool_.Run(
+          0, outer_pool_.NumWorkers(),
+          [this](uint64_t outer, size_t outer_thread) {
+            HWY_ASSERT(outer == outer_thread);  // each outer has one task
+            hwy::ThreadPool& inner = *inner_pools_[outer];
+
+            const std::vector<size_t> cores =
+                CoresInLPS(cores_per_cluster_[outer]);
+            // May have been capped by max_threads.
+            HWY_ASSERT(inner.NumWorkers() <= cores.size());
+
+            inner.Run(0, inner.NumWorkers(),
+                      [&cores](uint64_t task, size_t thread) {
+                        HWY_ASSERT(task == thread);  // each inner has one task
+                        hwy::PinThreadToLogicalProcessor(cores[task]);
+                      });
+          });
+    }
   }
 
   // Spinning reduces the latency of barrier synchronization, but wastes lots of
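On the known-topology path, the auto heuristic instead counts total workers across all clusters: pinning is enabled once outer workers times inner workers reaches 12. A minimal restatement as a hypothetical standalone helper (the commit inlines this logic in the constructor):

  // pin: -1 = auto, 0 = no, 1 = yes. In auto mode, pin only when the total
  // worker count (clusters x threads per cluster) reaches the diff's threshold.
  bool ShouldPin(int pin, size_t outer_workers, size_t inner_workers) {
    if (pin == -1) return outer_workers * inner_workers >= 12;
    return pin != 0;
  }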