1.1x prefill speedup, revamp threading in preparation for hierarchical parallelism.

Limit thread counts to the number of detected cores. Add a max_clusters arg.
Update detection logic to check for smt0 (use only the first SMT sibling of each core) - previously we pinned some workers to SMT siblings.

PiperOrigin-RevId: 659755311
Jan Wassenberg 2024-08-05 18:49:16 -07:00 committed by Copybara-Service
parent 1617e1a33d
commit 5e433e774a
12 changed files with 338 additions and 342 deletions
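Editor's sketch (not part of this commit): the frontend-facing flow after this change, pieced together from the run.cc and debug_prompt.cc diffs below.

// Sketch: construct the new per-cluster pools from AppArgs and hand them to Gemma.
#include "gemma/gemma.h"     // Gemma, KVCache
#include "util/app.h"        // LoaderArgs, InferenceArgs, AppArgs, CreateGemma
#include "util/threading.h"  // PerClusterPools

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);
  gcpp::AppArgs app(argc, argv);

  // Both arguments are upper bounds; 0 means "no limit beyond the detected
  // topology". num_threads is divided among the detected clusters.
  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);

  gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(
      loader.Info().model, inference.prefill_tbatch_size);
  // Generation then goes through model.Generate(), which starts and stops
  // spin-waiting on the pools internally.
  return 0;
}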

View File

@ -148,6 +148,16 @@ cc_library(
], ],
) )
cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)
cc_library( cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
@ -192,6 +202,7 @@ cc_library(
":tokenizer", ":tokenizer",
":kv_cache", ":kv_cache",
":weights", ":weights",
":threading",
"//compression:io", "//compression:io",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:bit_set", "@hwy//:bit_set",
@ -248,6 +259,7 @@ cc_library(
":cross_entropy", ":cross_entropy",
":gemma_lib", ":gemma_lib",
":kv_cache", ":kv_cache",
":threading",
# Placeholder for internal dep, do not remove., # Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark", "@benchmark//:benchmark",
"//compression:compress", "//compression:compress",
@ -286,10 +298,9 @@ cc_binary(
":benchmark_helper", ":benchmark_helper",
":common", ":common",
":gemma_lib", ":gemma_lib",
":threading",
# Placeholder for internal dep, do not remove., # Placeholder for internal dep, do not remove.,
"//compression:compress",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler", "@hwy//:profiler",
"@hwy//:thread_pool", "@hwy//:thread_pool",
], ],
@ -486,6 +497,7 @@ cc_test(
":optimizer", ":optimizer",
":prompt", ":prompt",
":sampler", ":sampler",
":threading",
":weights", ":weights",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"@hwy//:thread_pool", "@hwy//:thread_pool",

View File

@ -29,12 +29,14 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
TEST(OptimizeTest, GradientDescent) { TEST(OptimizeTest, GradientDescent) {
hwy::ThreadPool pool(0); PerClusterPools pools(1, 1);
hwy::ThreadPool& pool = pools.Inner(0);
std::mt19937 gen(42); std::mt19937 gen(42);
const ModelInfo info = { const ModelInfo info = {
@ -54,7 +56,7 @@ TEST(OptimizeTest, GradientDescent) {
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight); CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
Gemma gemma(GemmaTokenizer(), info, pool); Gemma gemma(GemmaTokenizer(), info, pools);
const auto generate = [&](const std::vector<int>& prompt) { const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply; std::vector<int> reply;

View File

@ -25,7 +25,6 @@
#include <ostream> #include <ostream>
#include <random> #include <random>
#include <string> #include <string>
#include <thread> // NOLINT
#include <utility> // std::pair #include <utility> // std::pair
#include <vector> #include <vector>
@ -37,6 +36,7 @@
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include "util/app.h" #include "util/app.h"
#include "util/args.h" #include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -61,12 +61,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
: loader_(loader), : loader_(loader),
inference_args_(inference), inference_args_(inference),
app_(app), app_(app),
pool_(app_.num_threads) { pools_(app_.max_clusters, app_.num_threads) {
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
PinWorkersToCores(pool_);
}
AbortIfInvalidArgs(inference_args_); AbortIfInvalidArgs(inference_args_);
if (const char* err = loader_.Validate()) { if (const char* err = loader_.Validate()) {
@ -74,7 +69,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Skipping model load because: %s\n", err); fprintf(stderr, "Skipping model load because: %s\n", err);
} else { } else {
fprintf(stderr, "Loading model...\n"); fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_); model_ = AllocateGemma(loader_, pools_);
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1); kv_caches_.resize(1);
@ -234,7 +229,8 @@ void LogSpeedStats(double time_start, size_t total_tokens) {
<< " [" << tok_sec << " tokens / sec" << "]\n"; << " [" << tok_sec << " tokens / sec" << "]\n";
} }
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools) {
loader.Print(app.verbosity); loader.Print(app.verbosity);
inference.Print(app.verbosity); inference.Print(app.verbosity);
app.Print(app.verbosity); app.Print(app.verbosity);
@ -242,22 +238,23 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
if (app.verbosity >= 2) { if (app.verbosity >= 2) {
time_t now = time(nullptr); time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT char* dt = ctime(&now); // NOLINT
// TODO: replace hardware_concurrency with detected topology. char cpu100[100] = "unknown";
std::cout << "Date & Time : " << dt (void)hwy::platform::GetCpuString(cpu100);
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n" fprintf(stderr,
<< "Instruction set : " "Date & Time : %s" // dt includes \n
<< hwy::TargetName(hwy::DispatchedTarget()) << " (" "CPU : %s\n"
<< hwy::VectorBytes() * 8 << " bits)" << "\n"; "CPU topology : %zux%zu, using %zux%zu\n"
char cpu100[100]; "Instruction set : %s (%zu bits)\n"
if (hwy::platform::GetCpuString(cpu100)) { "Compiled config : %s\n"
std::cout << "CPU : " << cpu100 << "\n"; "Weight Type : %s\n"
} "EmbedderInput Type : %s\n",
std::cout << "Compiled config : " << CompiledConfig() << "\n" dt, cpu100, pools.CoresPerCluster().size(),
<< "Weight Type : " pools.CoresPerCluster()[0].Count(), pools.Outer().NumWorkers(),
<< StringFromType(loader.Info().weight) << "\n" pools.Inner(0).NumWorkers(),
<< "EmbedderInput Type : " hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
<< TypeName(EmbedderInputT()) << "\n"; CompiledConfig(), StringFromType(loader.Info().weight),
TypeName(EmbedderInputT()));
} }
} }

View File

@ -26,8 +26,8 @@
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/app.h" #include "util/app.h"
#include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
@ -98,7 +98,7 @@ class GemmaEnv {
// Controls overall behavior of the app. // Controls overall behavior of the app.
AppArgs app_; AppArgs app_;
// Thread pool for running inference. // Thread pool for running inference.
hwy::ThreadPool pool_; PerClusterPools pools_;
// Random number generator. // Random number generator.
std::mt19937 gen_; std::mt19937 gen_;
// The model to run inference on. // The model to run inference on.
@ -111,7 +111,8 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec. // Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens); void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
} // namespace gcpp } // namespace gcpp

View File

@ -15,6 +15,7 @@ cc_binary(
"//:args", "//:args",
"//:common", "//:common",
"//:gemma_lib", "//:gemma_lib",
"//:threading",
"//:tokenizer", "//:tokenizer",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:thread_pool", "@hwy//:thread_pool",

View File

@ -26,12 +26,14 @@
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs #include "util/app.h" // LoaderArgs
#include "util/args.h" #include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv); gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv); gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) { if (gcpp::HasHelp(argc, argv)) {
loader.Help(); loader.Help();
return 0; return 0;
@ -41,8 +43,8 @@ int main(int argc, char** argv) {
} }
// Instantiate model and KV Cache // Instantiate model and KV Cache
hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount()); gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache = gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
size_t pos = 0; // KV Cache position size_t pos = 0; // KV Cache position

View File

@ -42,11 +42,11 @@
#include "ops/matmul-inl.h" #include "ops/matmul-inl.h"
#include "ops/matvec-inl.h" #include "ops/matvec-inl.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/bit_set.h" #include "hwy/bit_set.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/timer.h" #include "hwy/timer.h"
@ -575,143 +575,28 @@ HWY_NOINLINE void TransformerLayer(
// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt // NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
// to their num_queries. // to their num_queries.
class PrefillState { class PrefillState {
// TODO: move helper functions, also those in app.h, to a threading header
using LPS = hwy::LogicalProcessorSet;
LPS Intersection(const LPS& big_set, const LPS& small_set) {
LPS both_set;
// Reduce expected work by iterating over the smaller set.
small_set.Foreach([&big_set, &both_set](size_t idx) {
if (big_set.Get(idx)) both_set.Set(idx);
});
return both_set;
}
std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
// For each cluster (shared L3 cache), a bitset of cores.
using CoresPerCluster = std::vector<LPS>;
// Returns empty if detection failed.
CoresPerCluster DetectClusters() {
CoresPerCluster clusters;
// Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, better to use just a single inner pool rather than risk
// oversubscribing to enabled cores.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// For each cluster = outer, the cores that will be used for an inner pool.
CoresPerCluster inner_lps;
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
// Only use enabled cores, and only add if not empty.
const LPS lps = Intersection(enabled, cluster.lps);
if (lps.Any()) clusters.push_back(lps);
}
}
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
// Returns false if the main pool should be reused instead.
bool AssignInnerPoolsToClusters(const size_t num_queries) {
HWY_ASSERT(num_queries != 0);
CoresPerCluster inner_lps = DetectClusters();
// If we have more outer workers than queries, discard the excess.
if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
// If we're not going to create multiple pools, avoid the overhead of
// re-pinning (60 ms) and reuse the main pool.
if (inner_lps.size() <= 1) return false;
// Before creating new threads, stop the old ones from spinning. Caller is
// responsible for undoing this by calling `ResumeMainSpinning`.
main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
HWY_ASSERT(inner_pools_.empty());
for (const LPS& inner : inner_lps) {
inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
// For each inner pool, pin their threads AND the associated outer thread
// to the enabled cores in the cluster.
outer_pool_->Run(
0, inner_lps.size(),
[this, &inner_lps](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);
inner_pools_[outer]->Run(
0, cores.size(), [&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
return true;
}
void ReuseMainPoolAsInner() {
// Still allocate an empty pool to simplify Prefill().
outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
HWY_ASSERT(inner_pools_.empty());
inner_pools_.push_back(main_pool_);
}
public: public:
// Creates pools. AllocateActivations must still be called separately; it has
// a template argument.
PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
: main_pool_(&main_pool) {
PROFILER_ZONE("Init.Prefill.Ctor");
if (!AssignInnerPoolsToClusters(num_queries)) {
ReuseMainPoolAsInner();
}
}
~PrefillState() {
for (hwy::ThreadPool* p : inner_pools_) {
if (p != main_pool_) delete p;
}
}
// `tbatch_size` is the number of tokens from one query to prefill at a time. // `tbatch_size` is the number of tokens from one query to prefill at a time.
template <class TConfig> template <class TConfig>
void AllocateActivations(size_t num_queries, size_t tbatch_size) { void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
PROFILER_ZONE("Init.Prefill.AllocateActivations"); PROFILER_ZONE("Init.Prefill");
HWY_ASSERT(num_queries != 0);
const size_t outer_workers = outer_pool_->NumWorkers();
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
HWY_ASSERT(activations_.empty()); // only call once. HWY_ASSERT(activations_.empty()); // only call once.
activations_.resize(outer_workers);
if (outer_workers == 1) { // Allocate one activation per query, not outer worker, because the common
// case is a single query. If we allocate the lesser of the two, it is
// unclear how to choose an unused activation in Prefill.
activations_.resize(num_queries);
if (num_queries == 1) {
activations_[0].Allocate<TConfig>(tbatch_size); activations_[0].Allocate<TConfig>(tbatch_size);
} else { } else {
// Allocating in parallel can save 30 ms. // Allocating in parallel can save 30 ms. We might have more workers than
main_pool_->Run(0, outer_workers, // queries/tasks, so do not check the `thread` argument.
[this, tbatch_size](uint64_t task, size_t /*thread*/) { pools.Outer().Run(0, num_queries,
activations_[task].Allocate<TConfig>(tbatch_size); [this, tbatch_size](uint64_t qi, size_t /*thread*/) {
}); activations_[qi].Allocate<TConfig>(tbatch_size);
});
} }
} }
@ -721,7 +606,7 @@ class PrefillState {
const size_t query_idx_start, const size_t query_idx_start,
const CompressedWeights<TConfig>& weights, const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const KVCaches& kv_caches) { const KVCaches& kv_caches, PerClusterPools& pools) {
PROFILER_ZONE("Gen.Prefill"); PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = prompts.size(); const size_t num_queries = prompts.size();
HWY_ASSERT(kv_caches.size() == num_queries); HWY_ASSERT(kv_caches.size() == num_queries);
@ -729,10 +614,10 @@ class PrefillState {
// For each query (parallel): an outer worker processes all its tokens. // For each query (parallel): an outer worker processes all its tokens.
// `qi` is relative to the batch, not the global query index. // `qi` is relative to the batch, not the global query index.
outer_pool_->Run( pools.Outer().Run(
0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR { 0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
Activations& activations = activations_[qthread]; Activations& activations = activations_[qi];
hwy::ThreadPool& inner_pool = *inner_pools_[qthread]; hwy::ThreadPool& inner_pool = pools.Inner(qthread);
// Single query at a time, so pass a slice of the KV cache because // Single query at a time, so pass a slice of the KV cache because
// GemmaAttention will only access the first. // GemmaAttention will only access the first.
@ -768,29 +653,8 @@ class PrefillState {
}); });
} }
// Stops spinning in our pools and resume spinning in main_pool_.
void ResumeMainSpinning() {
// If we didn't create a new inner pool, we didn't stop spinning on the
// main pool, so nothing to do here.
if (inner_pools_[0] == main_pool_) return;
for (hwy::ThreadPool* p : inner_pools_) {
p->SetWaitMode(hwy::PoolWaitMode::kBlock);
}
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
private: private:
hwy::ThreadPool* main_pool_; std::vector<Activations> activations_; // One per query, filled by Init.
std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated
// Holds a single pointer equal to main_pool_, or new allocations; in either
// case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
// allocation overhead for the common case of a single query.
std::vector<hwy::ThreadPool*> inner_pools_;
// size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
std::vector<Activations> activations_;
}; };
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode, // `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@ -945,12 +809,15 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, const size_t pos, const MultiplePromptsTokens& prompts, const size_t pos,
const size_t query_idx_start, const KVCaches& kv_caches, const size_t query_idx_start, const KVCaches& kv_caches,
hwy::ThreadPool& pool, TimingInfo& timing_info) { PerClusterPools& pools, TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize; constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights = const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get()); *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
// TODO: remove once all parallel sections support hierarchical parallelism.
hwy::ThreadPool& pool = pools.Inner(0);
const size_t num_queries = prompts.size(); const size_t num_queries = prompts.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize()); HWY_ASSERT(num_queries <= activations.x.BatchSize());
@ -984,14 +851,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const size_t prefill_per_query = min_prompt_size - 1; const size_t prefill_per_query = min_prompt_size - 1;
double prefill_start; double prefill_start;
{ {
PrefillState prefill(pool, num_queries); // TODO: move to Gemma, reuse across calls to Generate.
prefill.AllocateActivations<TConfig>(num_queries, PrefillState prefill;
runtime_config.prefill_tbatch_size); prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
pools);
prefill_start = hwy::platform::Now(); prefill_start = hwy::platform::Now();
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start, prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
weights, runtime_config, kv_caches); weights, runtime_config, kv_caches, pools);
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
prefill.ResumeMainSpinning();
} }
size_t interleaved_pos = (pos + prefill_per_query) * num_queries; size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
@ -1051,7 +918,7 @@ template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8, void GenerateSingleT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, KVCache& kv_cache, const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) { PerClusterPools& pools, TimingInfo& timing_info) {
const size_t num_queries = 1; const size_t num_queries = 1;
const size_t qbatch_start = 0; const size_t qbatch_start = 0;
@ -1062,14 +929,14 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
const KVCaches kv_caches{&kv_cache, num_queries}; const KVCaches kv_caches{&kv_cache, num_queries};
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos, GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
qbatch_start, kv_caches, pool, timing_info); qbatch_start, kv_caches, pools, timing_info);
} }
template <class TConfig> template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8, void GenerateBatchT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos, const MultiplePromptsTokens& prompts, size_t pos,
const KVCaches& kv_caches, hwy::ThreadPool& pool, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_ASSERT(prompts.size() == kv_caches.size()); HWY_ASSERT(prompts.size() == kv_caches.size());
// Griffin does not support query batching. // Griffin does not support query batching.
@ -1089,7 +956,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
qbatch_size); qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts, GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
pos, qbatch_start, qbatch_kv, pool, timing_info); pos, qbatch_start, qbatch_kv, pools, timing_info);
} }
} }
@ -1102,18 +969,18 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
void GenerateSingle( // NOLINT(misc-definitions-in-headers) void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) { KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info); (weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
} }
void GenerateBatch( // NOLINT(misc-definitions-in-headers) void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8, GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool, size_t pos, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info); (weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
} }
#endif // HWY_ONCE #endif // HWY_ONCE
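Editor's note: a minimal sketch of the outer/inner nesting that PrefillState::Prefill adopts above. Only Outer(), Inner() and hwy::ThreadPool::Run come from this commit; the function name and per-query body are illustrative stand-ins.

#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// One outer worker per query, each driving the inner pool of its cluster so
// that the workers for a query share a cache and synchronize cheaply.
void ForEachQueryHierarchical(gcpp::PerClusterPools& pools,
                              size_t num_queries) {
  pools.Outer().Run(0, num_queries,
                    [&](uint64_t qi, size_t outer_thread) {
    hwy::ThreadPool& inner = pools.Inner(outer_thread);
    inner.Run(0, inner.NumWorkers(), [&](uint64_t task, size_t thread) {
      // Token/tensor-level work for query qi would go here.
      (void)qi; (void)task; (void)thread;
    });
  });
}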

View File

@ -28,23 +28,25 @@
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
namespace gcpp { namespace gcpp {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool) const ModelInfo& info, PerClusterPools& pools)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) { : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool); weights_u8_ =
LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
} }
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool) PerClusterPools& pools)
: pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) { : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
HWY_ASSERT(info.weight == Type::kF32); HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ = weights_u8_ = CallForModel<float, AllocateCompressedWeights>(info.model,
CallForModel<float, AllocateCompressedWeights>(info.model, pool); pools_.Inner(0));
} }
Gemma::~Gemma() { Gemma::~Gemma() {
@ -60,12 +62,12 @@ Gemma::~Gemma() {
extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \ const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \ const PromptTokens& prompt, size_t pos, \
KVCache& kv_cache, hwy::ThreadPool& pool, \ KVCache& kv_cache, PerClusterPools& pools, \
TimingInfo& timing_info); \ TimingInfo& timing_info); \
extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \ const RuntimeConfig& runtime_config, \
const MultiplePromptsTokens& prompts, size_t pos, \ const MultiplePromptsTokens& prompts, size_t pos, \
const KVCaches& kv_caches, hwy::ThreadPool& pool, \ const KVCaches& kv_caches, PerClusterPools& pools, \
TimingInfo& timing_info); TimingInfo& timing_info);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
@ -75,9 +77,9 @@ struct GenerateSingleT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, KVCache& kv_cache, const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) const { PerClusterPools& pools, TimingInfo& timing_info) const {
GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache, GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info); pools, timing_info);
} }
}; };
@ -86,36 +88,36 @@ struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos, const MultiplePromptsTokens& prompts, size_t pos,
const KVCaches& kv_caches, hwy::ThreadPool& pool, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos, GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
kv_caches, pool, timing_info); kv_caches, pools, timing_info);
} }
}; };
void Gemma::Generate(const RuntimeConfig& runtime_config, void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t start_pos, const PromptTokens& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info) { KVCache& kv_cache, TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); pools_.StartSpinning();
CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_, CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
runtime_config, prompt, start_pos, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info); kv_cache, pools_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); pools_.StopSpinning();
} }
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, const MultiplePromptsTokens& prompts,
size_t start_pos, const KVCaches& kv_caches, size_t start_pos, const KVCaches& kv_caches,
TimingInfo& timing_info) { TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); pools_.StartSpinning();
CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_, CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
runtime_config, prompts, start_pos, runtime_config, prompts, start_pos,
kv_caches, pool_, timing_info); kv_caches, pools_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); pools_.StopSpinning();
} }
} // namespace gcpp } // namespace gcpp

View File

@ -27,6 +27,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
@ -56,8 +57,8 @@ using SampleFunc = std::function<int(const float*, size_t)>;
// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer // - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
// - pointer to the data array // - pointer to the data array
// - size of the data array // - size of the data array
using LayersOutputFunc = using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
std::function<void(size_t, size_t, const std::string&, int, const float*, size_t)>; int, const float*, size_t)>;
struct RuntimeConfig { struct RuntimeConfig {
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
@ -121,11 +122,11 @@ using KVCaches = hwy::Span<KVCache>;
class Gemma { class Gemma {
public: public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
hwy::ThreadPool& pool); PerClusterPools& pools);
// Allocates weights, caller is responsible for filling them. // Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool); PerClusterPools& pools);
~Gemma(); ~Gemma();
const ModelInfo& Info() const { return info_; } const ModelInfo& Info() const { return info_; }
@ -140,7 +141,7 @@ class Gemma {
const KVCaches& kv_caches, TimingInfo& timing_info); const KVCaches& kv_caches, TimingInfo& timing_info);
private: private:
hwy::ThreadPool& pool_; PerClusterPools& pools_;
GemmaTokenizer tokenizer_; GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header. // Type-erased so that this can be defined in the header.
@ -155,14 +156,6 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos, const ModelInfo& info, size_t pos,
std::string& prompt); std::string& prompt);
// DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -27,6 +27,7 @@
#include "gemma/gemma.h" // Gemma #include "gemma/gemma.h" // Gemma
#include "util/app.h" #include "util/app.h"
#include "util/args.h" // HasHelp #include "util/args.h" // HasHelp
#include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
@ -75,9 +76,9 @@ std::string GetPrompt(std::istream& input, int verbosity,
} }
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
const InferenceArgs& args, int verbosity, int verbosity, const AcceptFunc& accept_token,
const AcceptFunc& accept_token, std::string& eot_line) { std::string& eot_line) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // absolute token index over all turns size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn int current_pos = 0; // token index within the current turn
@ -172,13 +173,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc"); PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads); // Note that num_threads is an upper bound; we also limit to the number of
// For many-core, pinning workers to cores helps. // detected and enabled cores.
if (app.num_threads > 10) { PerClusterPools pools(app.max_clusters, app.num_threads);
PinWorkersToCores(pool);
}
Gemma model = CreateGemma(loader, pool); Gemma model = CreateGemma(loader, pools);
KVCache kv_cache = KVCache kv_cache =
KVCache::Create(model.Info().model, inference.prefill_tbatch_size); KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
@ -205,11 +204,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cout << "\033[2J\033[1;1H" // clear screen std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n"; << kAsciiArtBanner << "\n\n";
ShowConfig(loader, inference, app); ShowConfig(loader, inference, app, pools);
std::cout << "\n" << instructions << "\n"; std::cout << "\n" << instructions << "\n";
} }
ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(), ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
app.eot_line); app.eot_line);
} }

View File

@ -23,20 +23,12 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/gemma.h" // For CreateGemma
#include "gemma/gemma.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_IS_ASAN
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#if HWY_OS_LINUX
#include <sched.h>
#endif // HWY_OS_LINUX
namespace gcpp { namespace gcpp {
@ -58,86 +50,14 @@ static inline const char* CompiledConfig() {
} }
} }
static inline std::vector<size_t> LpsToCpus(
const hwy::LogicalProcessorSet& lps) {
std::vector<size_t> cpus;
cpus.reserve(lps.Count());
lps.Foreach([&cpus](size_t lp) { cpus.push_back(lp); });
return cpus;
}
static inline std::vector<size_t> AssignCpusFromTopology(
const hwy::Topology& topology, const size_t num_workers) {
// Assign CPUs to workers 0 to num_workers - 1 based on the topology.
// The assignments are done in a round-robin fashion across all clusters and
// Cores.
// For example, if we have 4 clusters, the assignments will be:
// Thread 0 -> Cluster 0, Core 0
// Thread 1 -> Cluster 1, Core 0
// Thread 2 -> Cluster 2, Core 0
// Thread 3 -> Cluster 3, Core 0
// Thread 4 -> Cluster 0, Core 1
// Thread 5 -> Cluster 1, Core 1
// ... and so on.
//
// This would result in the least amount of sharing of the last-level
// cache slices. All assignments are made from Package 0.
std::vector<std::vector<size_t>> clusters;
for (auto& package : topology.packages) {
for (auto& cluster : package.clusters) {
clusters.push_back(LpsToCpus(cluster.lps));
}
}
std::vector<size_t> assigned_cpus;
assigned_cpus.reserve(num_workers);
for (size_t i = 0; i < num_workers; ++i) {
size_t cluster_index = i % clusters.size();
size_t cpu_index = (i / clusters.size()) % clusters[cluster_index].size();
assigned_cpus.push_back(clusters[cluster_index][cpu_index]);
}
return assigned_cpus;
}
static inline void PinWorkersToCores(hwy::ThreadPool& pool) {
// Use topology to pin workers to cores if available.
hwy::Topology topology;
if (!topology.packages.empty()) {
std::vector<size_t> assigned_cpus =
AssignCpusFromTopology(topology, pool.NumWorkers());
pool.Run(0, pool.NumWorkers(),
[&assigned_cpus](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(assigned_cpus[thread]);
});
} else {
pool.Run(0, pool.NumWorkers(), [](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(thread);
});
}
}
class AppArgs : public ArgsBase<AppArgs> { class AppArgs : public ArgsBase<AppArgs> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = GetSupportedThreadCount();
}
}
public: public:
AppArgs(int argc, char* argv[]) { AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InitAndParse(argc, argv);
ChooseNumThreads();
}
static inline size_t GetSupportedThreadCount() {
return HWY_MIN(hwy::ThreadPool::MaxThreads(), kMaxThreads);
}
Path log; // output Path log; // output
int verbosity; int verbosity;
size_t num_threads; size_t num_threads; // divided among the detected clusters
size_t max_clusters;
std::string eot_line; std::string eot_line;
template <class Visitor> template <class Visitor>
@ -147,11 +67,10 @@ class AppArgs : public ArgsBase<AppArgs> {
"output\n 1 = standard user-facing terminal ui\n 2 = show " "output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.", "developer/debug info).\n Default = 1.",
2); 2);
visitor(num_threads, "num_threads", visitor(num_threads, "num_threads", size_t{0},
kDefaultNumThreads, // see ChooseNumThreads "Maximum number of threads to use; default 0 = unlimited.", 2);
"Number of threads to use.\n Default = Estimate of the " visitor(max_clusters, "max_clusters", size_t{0},
"number of supported concurrent threads.", "Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
2);
visitor( visitor(
eot_line, "eot_line", std::string(""), eot_line, "eot_line", std::string(""),
"End of turn line. " "End of turn line. "
@ -232,14 +151,14 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
}; };
static inline Gemma CreateGemma(const LoaderArgs& loader, static inline Gemma CreateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) { PerClusterPools& pools) {
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool); return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
} }
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader, static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) { PerClusterPools& pools) {
return std::make_unique<Gemma>(loader.tokenizer, loader.weights, return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), pool); loader.Info(), pools);
} }
struct InferenceArgs : public ArgsBase<InferenceArgs> { struct InferenceArgs : public ArgsBase<InferenceArgs> {

util/threading.h (new file, 201 lines)
View File

@ -0,0 +1,201 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::sort
#include <memory>
#include <vector>
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
namespace gcpp {
// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
// 'outer' thread pool with one worker per cluster.
//
// Useful for hierarchical parallelism, which makes sense when there are few
// but large tasks which should be parallelized by workers sharing a cache.
// This also implies lower latency for barrier synchronization of those workers.
class PerClusterPools {
using LPS = hwy::LogicalProcessorSet;
static inline std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
using CoreBitSets = std::vector<LPS>;
// Returns empty if detection failed.
CoreBitSets DetectCoresPerCluster() {
CoreBitSets clusters;
if (!have_threading_support_) return clusters;
// Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, return empty (the caller then falls back to a single
// pool) rather than risk oversubscribing.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// Merge all clusters into one set, as a stopgap to emulate gemma-inl's
// prior single pool.
// TODO: remove once MatMul supports hierarchical parallelism.
LPS all;
// For each cluster, add its enabled *cores*.
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
cluster.lps.Foreach([&](size_t lp) {
if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
all.Set(lp);
}
});
}
/* code to reinstate:
for (const hwy::Topology::Cluster& cluster : package.clusters) {
  // Only use enabled *cores*, and only add if not empty.
  LPS lps;
  cluster.lps.Foreach([&](size_t lp) {
    if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
      lps.Set(lp);
    }
  });
  if (lps.Any()) clusters.push_back(lps);
}
*/
}
if (all.Any()) clusters.push_back(all);
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
outer_pool_.SetWaitMode(wait_mode);
for (auto& inner : inner_pools_) {
inner->SetWaitMode(wait_mode);
}
}
// The defaults for `AppArgs` `max_clusters` and `num_threads` are zero, which
// means no limit.
size_t CapIfNonzero(size_t detected, size_t user_max_or_zero) {
if (!have_threading_support_) return 0;
return (user_max_or_zero == 0) ? detected
: HWY_MIN(detected, user_max_or_zero);
}
public:
// PerClusterPools supports spin waits (see StartSpinning below). To prevent
// drastic slowdowns caused by excessive user-specified thread counts, which
// result in threads not running on their own core, we only allow for
// *upper bounds* on the number of clusters and threads. The actual numbers of
// clusters and threads are still limited by the detected topology.
PerClusterPools(size_t max_clusters, size_t max_threads)
: have_threading_support_(hwy::HaveThreadingSupport()),
cores_per_cluster_(DetectCoresPerCluster()),
outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
// Topology detection failed - it currently requires Linux.
if (cores_per_cluster_.empty()) {
// Create a single inner pool with up to TotalLogicalProcessors() / 2
// workers, further limited by `max_threads` if nonzero, and then pin to
// the first N processors, which are typically on the first socket.
const size_t num_threads =
CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
if (num_threads > 1) {
inner_pools_.back()->Run(0, num_threads,
[](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(thread);
});
}
return;
}
const size_t max_per_inner = max_threads / outer_pool_.NumWorkers();
for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
const size_t num_threads =
CapIfNonzero(cores_per_cluster_[outer].Count(), max_per_inner);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
}
// For each inner pool, pin their threads AND the associated outer thread
// (the one calling inner.Run()) to the enabled cores in the cluster.
outer_pool_.Run(
0, outer_pool_.NumWorkers(),
[this](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
hwy::ThreadPool& inner = *inner_pools_[outer];
const std::vector<size_t> cores =
CoresInLPS(cores_per_cluster_[outer]);
// May have been capped by max_threads.
HWY_ASSERT(inner.NumWorkers() <= cores.size());
inner.Run(0, inner.NumWorkers(),
[&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
}
// Spinning reduces the latency of barrier synchronization, but wastes lots of
// energy for long waits, so only do it during generation. This might also be
// unsafe in virtualized environments because we require threads to be running
// on their own core and thus responsive to the barrier synchronization.
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
// Bitsets of enabled cores, one per cluster, or empty if detection failed.
// Useful for displaying the topology.
const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; }
hwy::ThreadPool& Outer() { return outer_pool_; }
hwy::ThreadPool& Inner(size_t outer) {
HWY_ASSERT(outer < Outer().NumWorkers());
return *inner_pools_[outer];
}
private:
bool have_threading_support_;
CoreBitSets cores_per_cluster_;
hwy::ThreadPool outer_pool_;
// hwy::ThreadPool is unfortunately not marked as movable, so we have to use
// unique_ptr.
std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
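Editor's sketch of using the header above on its own: construct pools without limits and report what was detected, similar to the "CPU topology" line ShowConfig now prints. Only the PerClusterPools members shown above are assumed.

#include <stdio.h>
#include "util/threading.h"

int main() {
  gcpp::PerClusterPools pools(/*max_clusters=*/0, /*max_threads=*/0);
  if (pools.CoresPerCluster().empty()) {
    // Topology detection failed (currently requires Linux); single inner pool.
    fprintf(stderr, "Topology unknown; %zu inner workers\n",
            pools.Inner(0).NumWorkers());
  } else {
    fprintf(stderr, "%zu cluster(s); largest has %zu enabled cores; using %zux%zu workers\n",
            pools.CoresPerCluster().size(), pools.CoresPerCluster()[0].Count(),
            pools.Outer().NumWorkers(), pools.Inner(0).NumWorkers());
  }
  return 0;
}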