From 5e433e774a0c890b44f1dc2684b08b8be4f25601 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 5 Aug 2024 18:49:16 -0700 Subject: [PATCH] 1.1x prefill speedup, revamp threading in preparation for hierarchical parallelism. Limit thread counts to detected. Add max_clusters arg. Update detection logic to check for smt0 - previously we pinned to some siblings. PiperOrigin-RevId: 659755311 --- BUILD.bazel | 16 ++- backprop/optimize_test.cc | 6 +- evals/benchmark_helper.cc | 47 ++++---- evals/benchmark_helper.h | 7 +- examples/hello_world/BUILD | 1 + examples/hello_world/run.cc | 6 +- gemma/gemma-inl.h | 209 +++++++----------------------------- gemma/gemma.cc | 40 +++---- gemma/gemma.h | 19 ++-- gemma/run.cc | 21 ++-- util/app.h | 107 +++--------------- util/threading.h | 201 ++++++++++++++++++++++++++++++++++ 12 files changed, 338 insertions(+), 342 deletions(-) create mode 100644 util/threading.h diff --git a/BUILD.bazel b/BUILD.bazel index 4dd4697..6620122 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -148,6 +148,16 @@ cc_library( ], ) +cc_library( + name = "threading", + hdrs = ["util/threading.h"], + deps = [ + "@hwy//:hwy", + "@hwy//:thread_pool", + "@hwy//:topology", + ], +) + cc_library( name = "gemma_lib", srcs = [ @@ -192,6 +202,7 @@ cc_library( ":tokenizer", ":kv_cache", ":weights", + ":threading", "//compression:io", "@hwy//:hwy", "@hwy//:bit_set", @@ -248,6 +259,7 @@ cc_library( ":cross_entropy", ":gemma_lib", ":kv_cache", + ":threading", # Placeholder for internal dep, do not remove., "@benchmark//:benchmark", "//compression:compress", @@ -286,10 +298,9 @@ cc_binary( ":benchmark_helper", ":common", ":gemma_lib", + ":threading", # Placeholder for internal dep, do not remove., - "//compression:compress", "@hwy//:hwy", - "@hwy//:nanobenchmark", "@hwy//:profiler", "@hwy//:thread_pool", ], @@ -486,6 +497,7 @@ cc_test( ":optimizer", ":prompt", ":sampler", + ":threading", ":weights", "@googletest//:gtest_main", "@hwy//:thread_pool", diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 8df94e9..6a4f030 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -29,12 +29,14 @@ #include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/weights.h" +#include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { TEST(OptimizeTest, GradientDescent) { - hwy::ThreadPool pool(0); + PerClusterPools pools(1, 1); + hwy::ThreadPool& pool = pools.Inner(0); std::mt19937 gen(42); const ModelInfo info = { @@ -54,7 +56,7 @@ TEST(OptimizeTest, GradientDescent) { CallForModelAndWeight(info.model, info.weight); KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); - Gemma gemma(GemmaTokenizer(), info, pool); + Gemma gemma(GemmaTokenizer(), info, pools); const auto generate = [&](const std::vector& prompt) { std::vector reply; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 4111bd9..04a27ed 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -25,7 +25,6 @@ #include #include #include -#include // NOLINT #include // std::pair #include @@ -37,6 +36,7 @@ #include "gemma/kv_cache.h" #include "util/app.h" #include "util/args.h" +#include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -61,12 +61,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, : loader_(loader), inference_args_(inference), app_(app), - pool_(app_.num_threads) { - // For many-core, pinning workers to 
cores helps.
-  if (app_.num_threads > 10) {
-    PinWorkersToCores(pool_);
-  }
-
+      pools_(app_.max_clusters, app_.num_threads) {
   AbortIfInvalidArgs(inference_args_);
 
   if (const char* err = loader_.Validate()) {
@@ -74,7 +69,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     fprintf(stderr, "Skipping model load because: %s\n", err);
   } else {
     fprintf(stderr, "Loading model...\n");
-    model_ = AllocateGemma(loader_, pool_);
+    model_ = AllocateGemma(loader_, pools_);
     // Only allocate one for starters because GenerateBatch might not be
     // called.
     kv_caches_.resize(1);
@@ -234,7 +229,8 @@ void LogSpeedStats(double time_start, size_t total_tokens) {
             << " [" << tok_sec << " tokens / sec" << "]\n";
 }
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
+                PerClusterPools& pools) {
   loader.Print(app.verbosity);
   inference.Print(app.verbosity);
   app.Print(app.verbosity);
@@ -242,22 +238,28 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
   if (app.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
-    // TODO: replace hardware_concurrency with detected topology.
-    std::cout << "Date & Time                   : " << dt
-              << "Hardware concurrency          : "
-              << std::thread::hardware_concurrency() << "\n"
-              << "Instruction set               : "
-              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
-              << hwy::VectorBytes() * 8 << " bits)" << "\n";
-    char cpu100[100];
-    if (hwy::platform::GetCpuString(cpu100)) {
-      std::cout << "CPU                           : " << cpu100 << "\n";
-    }
-    std::cout << "Compiled config               : " << CompiledConfig() << "\n"
-              << "Weight Type                   : "
-              << StringFromType(loader.Info().weight) << "\n"
-              << "EmbedderInput Type            : "
-              << TypeName(EmbedderInputT()) << "\n";
+    char cpu100[100] = "unknown";
+    (void)hwy::platform::GetCpuString(cpu100);
+
+    // CoresPerCluster() is empty if topology detection failed; avoid
+    // dereferencing element [0] in that case.
+    const size_t num_clusters = pools.CoresPerCluster().size();
+    const size_t cluster_cores =
+        num_clusters == 0 ? 0 : pools.CoresPerCluster()[0].Count();
+
+    fprintf(stderr,
+            "Date & Time                   : %s"  // dt includes \n
+            "CPU                           : %s\n"
+            "CPU topology                  : %zux%zu, using %zux%zu\n"
+            "Instruction set               : %s (%zu bits)\n"
+            "Compiled config               : %s\n"
+            "Weight Type                   : %s\n"
+            "EmbedderInput Type            : %s\n",
+            dt, cpu100, num_clusters, cluster_cores,
+            pools.Outer().NumWorkers(), pools.Inner(0).NumWorkers(),
+            hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
+            CompiledConfig(), StringFromType(loader.Info().weight),
+            TypeName(EmbedderInputT()));
   }
 }
 
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 21b0d2f..b9302d5 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -26,8 +26,8 @@
 
 #include "gemma/gemma.h"
 #include "util/app.h"
+#include "util/threading.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
@@ -98,7 +98,7 @@ class GemmaEnv {
   // Controls overall behavior of the app.
   AppArgs app_;
   // Thread pool for running inference.
-  hwy::ThreadPool pool_;
+  PerClusterPools pools_;
   // Random number generator.
   std::mt19937 gen_;
   // The model to run inference on.
@@ -111,7 +111,8 @@
 
 // Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens); -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); +void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, + PerClusterPools& pools); void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); } // namespace gcpp diff --git a/examples/hello_world/BUILD b/examples/hello_world/BUILD index 98fe5fd..ca7f426 100644 --- a/examples/hello_world/BUILD +++ b/examples/hello_world/BUILD @@ -15,6 +15,7 @@ cc_binary( "//:args", "//:common", "//:gemma_lib", + "//:threading", "//:tokenizer", "@hwy//:hwy", "@hwy//:thread_pool", diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index bfef76f..00425a3 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -26,12 +26,14 @@ #include "gemma/tokenizer.h" #include "util/app.h" // LoaderArgs #include "util/args.h" +#include "util/threading.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); + gcpp::AppArgs app(argc, argv); if (gcpp::HasHelp(argc, argv)) { loader.Help(); return 0; @@ -41,8 +43,8 @@ int main(int argc, char** argv) { } // Instantiate model and KV Cache - hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount()); - gcpp::Gemma model = gcpp::CreateGemma(loader, pool); + gcpp::PerClusterPools pools(app.max_clusters, app.num_threads); + gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); size_t pos = 0; // KV Cache position diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 633d898..4d96c2d 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -42,11 +42,11 @@ #include "ops/matmul-inl.h" #include "ops/matvec-inl.h" #include "ops/ops-inl.h" +#include "util/threading.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/contrib/thread_pool/topology.h" #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -575,143 +575,28 @@ HWY_NOINLINE void TransformerLayer( // NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt // to their num_queries. class PrefillState { - // TODO: move helper functions, also those in app.h, to a threading header - using LPS = hwy::LogicalProcessorSet; - LPS Intersection(const LPS& big_set, const LPS& small_set) { - LPS both_set; - // Reduce expected work by iterating over the smaller set. - small_set.Foreach([&big_set, &both_set](size_t idx) { - if (big_set.Get(idx)) both_set.Set(idx); - }); - return both_set; - } - - std::vector CoresInLPS(const LPS& cluster) { - std::vector cores; - cores.reserve(cluster.Count()); - cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); }); - return cores; - } - - // For each cluster (shared L3 cache), a bitset of cores. - using CoresPerCluster = std::vector; - - // Returns empty if detection failed. - CoresPerCluster DetectClusters() { - CoresPerCluster clusters; - // Which processors are not disabled via OS, taskset, or numactl. - LPS enabled; - // If we don't know, better to use just a single inner pool rather than risk - // oversubscribing to enabled cores. 
- if (!GetThreadAffinity(enabled)) return clusters; - - hwy::Topology topology; - if (topology.packages.empty()) return clusters; - - // For each cluster = outer, the cores that will be used for an inner pool. - CoresPerCluster inner_lps; - for (const hwy::Topology::Package& package : topology.packages) { - for (const hwy::Topology::Cluster& cluster : package.clusters) { - // Only use enabled cores, and only add if not empty. - const LPS lps = Intersection(enabled, cluster.lps); - if (lps.Any()) clusters.push_back(lps); - } - } - - // Sort by descending number of enabled cores, so that we preferentially - // use the largest clusters. - std::sort(clusters.begin(), clusters.end(), - [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); }); - - return clusters; - } - - // Returns false if the main pool should be reused instead. - bool AssignInnerPoolsToClusters(const size_t num_queries) { - HWY_ASSERT(num_queries != 0); - - CoresPerCluster inner_lps = DetectClusters(); - // If we have more outer workers than queries, discard the excess. - if (inner_lps.size() > num_queries) inner_lps.resize(num_queries); - // If we're not going to create multiple pools, avoid the overhead of - // re-pinning (60 ms) and reuse the main pool. - if (inner_lps.size() <= 1) return false; - - // Before creating new threads, stop the old ones from spinning. Caller is - // responsible for undoing this by calling `ResumeMainSpinning`. - main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock); - - outer_pool_ = std::make_unique(inner_lps.size()); - outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin); - - HWY_ASSERT(inner_pools_.empty()); - for (const LPS& inner : inner_lps) { - inner_pools_.push_back(new hwy::ThreadPool(inner.Count())); - inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin); - } - - // For each inner pool, pin their threads AND the associated outer thread - // to the enabled cores in the cluster. - outer_pool_->Run( - 0, inner_lps.size(), - [this, &inner_lps](uint64_t outer, size_t outer_thread) { - HWY_ASSERT(outer == outer_thread); // each outer has one task - const std::vector cores = CoresInLPS(inner_lps[outer]); - - inner_pools_[outer]->Run( - 0, cores.size(), [&cores](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each inner has one task - hwy::PinThreadToLogicalProcessor(cores[task]); - }); - }); - - return true; - } - - void ReuseMainPoolAsInner() { - // Still allocate an empty pool to simplify Prefill(). - outer_pool_ = std::make_unique(1); - - HWY_ASSERT(inner_pools_.empty()); - inner_pools_.push_back(main_pool_); - } - public: - // Creates pools. AllocateActivations must still be called separately; it has - // a template argument. - PrefillState(hwy::ThreadPool& main_pool, size_t num_queries) - : main_pool_(&main_pool) { - PROFILER_ZONE("Init.Prefill.Ctor"); - if (!AssignInnerPoolsToClusters(num_queries)) { - ReuseMainPoolAsInner(); - } - } - - ~PrefillState() { - for (hwy::ThreadPool* p : inner_pools_) { - if (p != main_pool_) delete p; - } - } - // `tbatch_size` is the number of tokens from one query to prefill at a time. template - void AllocateActivations(size_t num_queries, size_t tbatch_size) { - PROFILER_ZONE("Init.Prefill.AllocateActivations"); - - const size_t outer_workers = outer_pool_->NumWorkers(); - HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty. 
-
+  void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
+    PROFILER_ZONE("Init.Prefill");
+    HWY_ASSERT(num_queries != 0);
     HWY_ASSERT(activations_.empty());  // only call once.
-    activations_.resize(outer_workers);
-    if (outer_workers == 1) {
+    // Allocate one activation per query, not outer worker, because the common
+    // case is a single query. If we allocate the lesser of the two, it is
+    // unclear how to choose an unused activation in Prefill.
+    activations_.resize(num_queries);
+
+    if (num_queries == 1) {
       activations_[0].Allocate(tbatch_size);
     } else {
-      // Allocating in parallel can save 30 ms.
-      main_pool_->Run(0, outer_workers,
-                      [this, tbatch_size](uint64_t task, size_t /*thread*/) {
-                        activations_[task].Allocate(tbatch_size);
-                      });
+      // Allocating in parallel can save 30 ms. We might have more workers than
+      // queries/tasks, so do not check the `thread` argument.
+      pools.Outer().Run(0, num_queries,
+                        [this, tbatch_size](uint64_t qi, size_t /*thread*/) {
+                          activations_[qi].Allocate(tbatch_size);
+                        });
     }
   }
 
@@ -721,7 +606,7 @@ class PrefillState {
                        const size_t query_idx_start,
                        const CompressedWeights<TConfig>& weights,
                        const RuntimeConfig& runtime_config,
-                       const KVCaches& kv_caches) {
+                       const KVCaches& kv_caches, PerClusterPools& pools) {
     PROFILER_ZONE("Gen.Prefill");
     const size_t num_queries = prompts.size();
     HWY_ASSERT(kv_caches.size() == num_queries);
@@ -729,10 +614,10 @@ class PrefillState {
 
     // For each query (parallel): an outer worker processes all its tokens.
     // `qi` is relative to the batch, not the global query index.
-    outer_pool_->Run(
+    pools.Outer().Run(
         0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
-          Activations& activations = activations_[qthread];
-          hwy::ThreadPool& inner_pool = *inner_pools_[qthread];
+          Activations& activations = activations_[qi];
+          hwy::ThreadPool& inner_pool = pools.Inner(qthread);
 
           // Single query at a time, so pass a slice of the KV cache because
           // GemmaAttention will only access the first.
@@ -768,29 +653,8 @@ class PrefillState {
         });
   }
 
-  // Stops spinning in our pools and resume spinning in main_pool_.
-  void ResumeMainSpinning() {
-    // If we didn't create a new inner pool, we didn't stop spinning on the
-    // main pool, so nothing to do here.
-    if (inner_pools_[0] == main_pool_) return;
-
-    for (hwy::ThreadPool* p : inner_pools_) {
-      p->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    }
-    outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
-  }
-
 private:
-  hwy::ThreadPool* main_pool_;
-  std::unique_ptr<hwy::ThreadPool> outer_pool_;  // always allocated
-  // Holds a single pointer equal to main_pool_, or new allocations; in either
-  // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
-  // allocation overhead for the common case of a single query.
-  std::vector<hwy::ThreadPool*> inner_pools_;
-
-  // size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
-  std::vector<Activations> activations_;
+  std::vector<Activations> activations_;  // One per query, filled by Init.
 };
 
 // `tokens` is length `num_tokens * num_queries`.
In autoregressive decode,
@@ -945,12 +809,15 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const RuntimeConfig& runtime_config,
                const MultiplePromptsTokens& prompts, const size_t pos,
                const size_t query_idx_start, const KVCaches& kv_caches,
-               hwy::ThreadPool& pool, TimingInfo& timing_info) {
+               PerClusterPools& pools, TimingInfo& timing_info) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
 
+  // TODO: remove once all parallel sections support hierarchical parallelism.
+  hwy::ThreadPool& pool = pools.Inner(0);
+
   const size_t num_queries = prompts.size();
   HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
   HWY_ASSERT(num_queries <= activations.x.BatchSize());
@@ -984,14 +851,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   const size_t prefill_per_query = min_prompt_size - 1;
   double prefill_start;
   {
-    PrefillState prefill(pool, num_queries);
-    prefill.AllocateActivations(num_queries,
-                                runtime_config.prefill_tbatch_size);
+    // TODO: move to Gemma, reuse across calls to Generate.
+    PrefillState prefill;
+    prefill.Init(num_queries, runtime_config.prefill_tbatch_size,
+                 pools);
     prefill_start = hwy::platform::Now();
     prefill.Prefill(prompts, prefill_per_query, pos, query_idx_start,
-                    weights, runtime_config, kv_caches);
+                    weights, runtime_config, kv_caches, pools);
     timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
-    prefill.ResumeMainSpinning();
   }
 
   size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
@@ -1051,7 +918,7 @@ template <class TConfig>
 void GenerateSingleT(const ByteStorageT& weights_u8,
                      const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
-                     hwy::ThreadPool& pool, TimingInfo& timing_info) {
+                     PerClusterPools& pools, TimingInfo& timing_info) {
   const size_t num_queries = 1;
   const size_t qbatch_start = 0;
 
@@ -1062,14 +929,14 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
   const KVCaches kv_caches{&kv_cache, num_queries};
 
   GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
-                     qbatch_start, kv_caches, pool, timing_info);
+                     qbatch_start, kv_caches, pools, timing_info);
 }
 
 template <class TConfig>
 void GenerateBatchT(const ByteStorageT& weights_u8,
                     const RuntimeConfig& runtime_config,
                     const MultiplePromptsTokens& prompts, size_t pos,
-                    const KVCaches& kv_caches, hwy::ThreadPool& pool,
+                    const KVCaches& kv_caches, PerClusterPools& pools,
                     TimingInfo& timing_info) {
   HWY_ASSERT(prompts.size() == kv_caches.size());
   // Griffin does not support query batching.
@@ -1089,7 +956,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
                                                qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
-                       pos, qbatch_start, qbatch_kv, pool, timing_info);
+                       pos, qbatch_start, qbatch_kv, pools, timing_info);
   }
 }
 
@@ -1102,18 +969,18 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
     GEMMA_CONFIG, const ByteStorageT& weights_u8,
     const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
-    KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
+    KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT)
-  (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
+  (weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
 }
 
 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
     GEMMA_CONFIG, const ByteStorageT& weights_u8,
     const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
-    size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
+    size_t pos, const KVCaches& kv_caches, PerClusterPools& pools,
     TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT)
-  (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
+  (weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
 }
 
 #endif  // HWY_ONCE
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index d3b9b6a..ce5f07e 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -28,23 +28,25 @@
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 
 namespace gcpp {
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
-             const ModelInfo& info, hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
-  weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
+             const ModelInfo& info, PerClusterPools& pools)
+    : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
+  weights_u8_ =
+      LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
+             PerClusterPools& pools)
+    : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
   HWY_ASSERT(info.weight == Type::kF32);
-  weights_u8_ =
-      CallForModel<float, AllocateCompressedWeights>(info.model, pool);
+  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
+      info.model, pools_.Inner(0));
 }
 
 Gemma::~Gemma() {
@@ -60,12 +62,12 @@ Gemma::~Gemma() {
   extern void GenerateSingle(CONFIGT, const ByteStorageT& weights_u8,         \
                              const RuntimeConfig& runtime_config,             \
                              const PromptTokens& prompt, size_t pos,          \
-                             KVCache& kv_cache, hwy::ThreadPool& pool,        \
+                             KVCache& kv_cache, PerClusterPools& pools,       \
                              TimingInfo& timing_info);                        \
   extern void GenerateBatch(CONFIGT, const ByteStorageT& weights_u8,          \
                             const RuntimeConfig& runtime_config,              \
                             const MultiplePromptsTokens& prompts, size_t pos, \
-                            const KVCaches& kv_caches, hwy::ThreadPool& pool, \
+                            const KVCaches& kv_caches, PerClusterPools& pools, \
                             TimingInfo& timing_info);
 
 GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
 
@@ -75,9 +77,9 @@ struct GenerateSingleT {
   void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
                   const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
-
hwy::ThreadPool& pool, TimingInfo& timing_info) const { + PerClusterPools& pools, TimingInfo& timing_info) const { GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache, - pool, timing_info); + pools, timing_info); } }; @@ -86,36 +88,36 @@ struct GenerateBatchT { void operator()(const ByteStorageT& weights_u8, const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, size_t pos, - const KVCaches& kv_caches, hwy::ThreadPool& pool, + const KVCaches& kv_caches, PerClusterPools& pools, TimingInfo& timing_info) const { GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos, - kv_caches, pool, timing_info); + kv_caches, pools, timing_info); } }; void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info) { - pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); + pools_.StartSpinning(); CallForModelAndWeight(info_.model, info_.weight, weights_u8_, runtime_config, prompt, start_pos, - kv_cache, pool_, timing_info); + kv_cache, pools_, timing_info); - pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); + pools_.StopSpinning(); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts, size_t start_pos, const KVCaches& kv_caches, TimingInfo& timing_info) { - pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); + pools_.StartSpinning(); CallForModelAndWeight(info_.model, info_.weight, weights_u8_, runtime_config, prompts, start_pos, - kv_caches, pool_, timing_info); + kv_caches, pools_, timing_info); - pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); + pools_.StopSpinning(); } } // namespace gcpp diff --git a/gemma/gemma.h b/gemma/gemma.h index 081e783..da4028d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -27,6 +27,7 @@ #include "gemma/common.h" #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" +#include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" // IWYU pragma: end_exports @@ -56,8 +57,8 @@ using SampleFunc = std::function; // - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer // - pointer to the data array // - size of the data array -using LayersOutputFunc = - std::function; +using LayersOutputFunc = std::function; struct RuntimeConfig { bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const { @@ -121,11 +122,11 @@ using KVCaches = hwy::Span; class Gemma { public: Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, - hwy::ThreadPool& pool); + PerClusterPools& pools); // Allocates weights, caller is responsible for filling them. Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, - hwy::ThreadPool& pool); + PerClusterPools& pools); ~Gemma(); const ModelInfo& Info() const { return info_; } @@ -140,7 +141,7 @@ class Gemma { const KVCaches& kv_caches, TimingInfo& timing_info); private: - hwy::ThreadPool& pool_; + PerClusterPools& pools_; GemmaTokenizer tokenizer_; // Type-erased so that this can be defined in the header. @@ -155,14 +156,6 @@ std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, const ModelInfo& info, size_t pos, std::string& prompt); -// DEPRECATED, call Gemma::Generate directly. 
-HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
-                              const std::vector<int>& prompt, size_t start_pos,
-                              KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
-                              TimingInfo& timing_info) {
-  gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
-}
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
diff --git a/gemma/run.cc b/gemma/run.cc
index ce4902f..33a9e2b 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -27,6 +27,7 @@
 #include "gemma/gemma.h"  // Gemma
 #include "util/app.h"
 #include "util/args.h"  // HasHelp
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -75,9 +76,9 @@ std::string GetPrompt(std::istream& input, int verbosity,
 }
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
-               const InferenceArgs& args, int verbosity,
-               const AcceptFunc& accept_token, std::string& eot_line) {
+void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
+               int verbosity, const AcceptFunc& accept_token,
+               std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -172,13 +173,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
 
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    PinWorkersToCores(pool);
-  }
+  // Note that num_threads is an upper bound; we also limit to the number of
+  // detected and enabled cores.
+  PerClusterPools pools(app.max_clusters, app.num_threads);
 
-  Gemma model = CreateGemma(loader, pool);
+  Gemma model = CreateGemma(loader, pools);
   KVCache kv_cache =
       KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
 
@@ -205,11 +204,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\033[2J\033[1;1H"  // clear screen
               << kAsciiArtBanner << "\n\n";
-    ShowConfig(loader, inference, app);
+    ShowConfig(loader, inference, app, pools);
     std::cout << "\n" << instructions << "\n";
   }
 
-  ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
+  ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
             app.eot_line);
 }
 
diff --git a/util/app.h b/util/app.h
index 160d120..ef4efb8 100644
--- a/util/app.h
+++ b/util/app.h
@@ -23,20 +23,12 @@
 
 #include
 #include
-#include <vector>
 
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
-#include "gemma/configs.h"
-#include "gemma/gemma.h"
+#include "gemma/gemma.h"  // For CreateGemma
 #include "util/args.h"
-#include "hwy/base.h"  // HWY_ASSERT
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/contrib/thread_pool/topology.h"
-
-#if HWY_OS_LINUX
-#include <sched.h>
-#endif  // HWY_OS_LINUX
+#include "hwy/base.h"  // HWY_IS_ASAN
 
 namespace gcpp {
 
@@ -58,86 +50,14 @@ static inline const char* CompiledConfig() {
   }
 }
 
-static inline std::vector<size_t> LpsToCpus(
-    const hwy::LogicalProcessorSet& lps) {
-  std::vector<size_t> cpus;
-  cpus.reserve(lps.Count());
-  lps.Foreach([&cpus](size_t lp) { cpus.push_back(lp); });
-  return cpus;
-}
-
-static inline std::vector<size_t> AssignCpusFromTopology(
-    const hwy::Topology& topology, const size_t num_workers) {
-  // Assign CPUs to workers 0 to num_workers - 1 based on the topology.
-  // The assignments are done in a round-robin fashion across all clusters and
-  // cores.
-  // For example, if we have 4 clusters, the assignments will be:
-  // Thread 0 -> Cluster 0, Core 0
-  // Thread 1 -> Cluster 1, Core 0
-  // Thread 2 -> Cluster 2, Core 0
-  // Thread 3 -> Cluster 3, Core 0
-  // Thread 4 -> Cluster 0, Core 1
-  // Thread 5 -> Cluster 1, Core 1
-  // ... and so on.
-  //
-  // This would result in the least amount of sharing of the last-level
-  // cache slices. All assignments are made from Package 0.
-  std::vector<std::vector<size_t>> clusters;
-  for (auto& package : topology.packages) {
-    for (auto& cluster : package.clusters) {
-      clusters.push_back(LpsToCpus(cluster.lps));
-    }
-  }
-  std::vector<size_t> assigned_cpus;
-  assigned_cpus.reserve(num_workers);
-  for (size_t i = 0; i < num_workers; ++i) {
-    size_t cluster_index = i % clusters.size();
-    size_t cpu_index = (i / clusters.size()) % clusters[cluster_index].size();
-    assigned_cpus.push_back(clusters[cluster_index][cpu_index]);
-  }
-  return assigned_cpus;
-}
-
-static inline void PinWorkersToCores(hwy::ThreadPool& pool) {
-  // Use topology to pin workers to cores if available.
-  hwy::Topology topology;
-  if (!topology.packages.empty()) {
-    std::vector<size_t> assigned_cpus =
-        AssignCpusFromTopology(topology, pool.NumWorkers());
-    pool.Run(0, pool.NumWorkers(),
-             [&assigned_cpus](uint64_t /*task*/, size_t thread) {
-               hwy::PinThreadToLogicalProcessor(assigned_cpus[thread]);
-             });
-  } else {
-    pool.Run(0, pool.NumWorkers(), [](uint64_t /*task*/, size_t thread) {
-      hwy::PinThreadToLogicalProcessor(thread);
-    });
-  }
-}
-
 class AppArgs : public ArgsBase<AppArgs> {
-  static constexpr size_t kDefaultNumThreads = ~size_t{0};
-
-  void ChooseNumThreads() {
-    if (num_threads == kDefaultNumThreads) {
-      // This is a rough heuristic, replace with something better in the future.
-      num_threads = GetSupportedThreadCount();
-    }
-  }
-
 public:
-  AppArgs(int argc, char* argv[]) {
-    InitAndParse(argc, argv);
-    ChooseNumThreads();
-  }
-
-  static inline size_t GetSupportedThreadCount() {
-    return HWY_MIN(hwy::ThreadPool::MaxThreads(), kMaxThreads);
-  }
+  AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
   Path log;  // output
   int verbosity;
-  size_t num_threads;
+  size_t num_threads;  // divided among the detected clusters
+  size_t max_clusters;
   std::string eot_line;
 
   template <class Visitor>
@@ -147,11 +67,10 @@ class AppArgs : public ArgsBase<AppArgs> {
            "output\n    1 = standard user-facing terminal ui\n    2 = show "
            "developer/debug info).\n    Default = 1.",
            2);
-    visitor(num_threads, "num_threads",
-            kDefaultNumThreads,  // see ChooseNumThreads
-            "Number of threads to use.\n    Default = Estimate of the "
-            "number of supported concurrent threads.",
-            2);
+    visitor(num_threads, "num_threads", size_t{0},
+            "Maximum number of threads to use; default 0 = unlimited.", 2);
+    visitor(max_clusters, "max_clusters", size_t{0},
+            "Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
     visitor(
         eot_line, "eot_line", std::string(""),
         "End of turn line. "
" @@ -232,14 +151,14 @@ struct LoaderArgs : public ArgsBase { }; static inline Gemma CreateGemma(const LoaderArgs& loader, - hwy::ThreadPool& pool) { - return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool); + PerClusterPools& pools) { + return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools); } static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, - hwy::ThreadPool& pool) { + PerClusterPools& pools) { return std::make_unique(loader.tokenizer, loader.weights, - loader.Info(), pool); + loader.Info(), pools); } struct InferenceArgs : public ArgsBase { diff --git a/util/threading.h b/util/threading.h new file mode 100644 index 0000000..ec2a575 --- /dev/null +++ b/util/threading.h @@ -0,0 +1,201 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Shared between various frontends. + +#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ +#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ + +#include +#include + +#include // std::sort +#include +#include + +#include "hwy/base.h" // HWY_ASSERT +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/contrib/thread_pool/topology.h" + +namespace gcpp { + +// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an +// 'outer' thread pool with one worker per cluster. +// +// Useful for hierarchical parallelism, which makes sense when there are few +// but large tasks which should be parallelized by workers sharing a cache. +// This also implies lower latency for barrier synchronization of those workers. +class PerClusterPools { + using LPS = hwy::LogicalProcessorSet; + + static inline std::vector CoresInLPS(const LPS& cluster) { + std::vector cores; + cores.reserve(cluster.Count()); + cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); }); + return cores; + } + + using CoreBitSets = std::vector; + + // Returns empty if detection failed. + CoreBitSets DetectCoresPerCluster() { + CoreBitSets clusters; + if (!have_threading_support_) return clusters; + + // Which processors are not disabled via OS, taskset, or numactl. + LPS enabled; + // If we don't know, better to abort rather than risk oversubscribing. + if (!GetThreadAffinity(enabled)) return clusters; + + hwy::Topology topology; + if (topology.packages.empty()) return clusters; + + // Merge all clusters into one set, as a stopgap to emulate gemma-inl's + // prior single pool. + // TODO: remove once MatMul supports hierarchical parallelism. + LPS all; + + // For each cluster, add its enabled *cores*. + for (const hwy::Topology::Package& package : topology.packages) { + for (const hwy::Topology::Cluster& cluster : package.clusters) { + cluster.lps.Foreach([&](size_t lp) { + if (enabled.Get(lp) && topology.lps[lp].smt == 0) { + all.Set(lp); + } + }); + } + + /* code to reinstate: + for (const hwy::Topology::Cluster& cluster : package.clusters) { + // Only use enabled *cores*, and only add if not empty. 
+        LPS lps;
+        cluster.lps.Foreach([&](size_t lp) {
+          if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
+            lps.Set(lp);
+          }
+        });
+        if (lps.Any()) clusters.push_back(lps);
+      }
+      */
+    }
+    if (all.Any()) clusters.push_back(all);
+
+    // Sort by descending number of enabled cores, so that we preferentially
+    // use the largest clusters.
+    std::sort(clusters.begin(), clusters.end(),
+              [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
+
+    return clusters;
+  }
+
+  void SetWaitMode(hwy::PoolWaitMode wait_mode) {
+    outer_pool_.SetWaitMode(wait_mode);
+    for (auto& inner : inner_pools_) {
+      inner->SetWaitMode(wait_mode);
+    }
+  }
+
+  // The defaults for `AppArgs` `max_clusters` and `num_threads` are zero,
+  // which means no limit.
+  size_t CapIfNonzero(size_t detected, size_t user_max_or_zero) {
+    if (!have_threading_support_) return 0;
+    return (user_max_or_zero == 0) ? detected
+                                   : HWY_MIN(detected, user_max_or_zero);
+  }
+
+ public:
+  // PerClusterPools supports spin waits (see StartSpinning below). To prevent
+  // drastic slowdowns caused by excessive user-specified thread counts, which
+  // result in threads not running on their own core, we only allow
+  // *upper bounds* on the number of clusters and threads. The actual numbers
+  // are still limited by the detected topology.
+  PerClusterPools(size_t max_clusters, size_t max_threads)
+      : have_threading_support_(hwy::HaveThreadingSupport()),
+        cores_per_cluster_(DetectCoresPerCluster()),
+        outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
+    // Topology detection failed - it currently requires Linux.
+    if (cores_per_cluster_.empty()) {
+      // Create a single inner pool with up to TotalLogicalProcessors() / 2
+      // workers, further limited by `max_threads` if nonzero, and then pin to
+      // the first N processors, which are typically on the first socket.
+      const size_t num_threads =
+          CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
+      fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
+      inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
+      if (num_threads > 1) {
+        inner_pools_.back()->Run(0, num_threads,
+                                 [](uint64_t /*task*/, size_t thread) {
+                                   hwy::PinThreadToLogicalProcessor(thread);
+                                 });
+      }
+      return;
+    }
+
+    const size_t max_per_inner = max_threads / outer_pool_.NumWorkers();
+    for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
+      const size_t num_threads =
+          CapIfNonzero(cores_per_cluster_[outer].Count(), max_per_inner);
+      inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
+    }
+
+    // For each inner pool, pin their threads AND the associated outer thread
+    // (the one calling inner.Run()) to the enabled cores in the cluster.
+    outer_pool_.Run(
+        0, outer_pool_.NumWorkers(),
+        [this](uint64_t outer, size_t outer_thread) {
+          HWY_ASSERT(outer == outer_thread);  // each outer has one task
+          hwy::ThreadPool& inner = *inner_pools_[outer];
+
+          const std::vector<size_t> cores =
+              CoresInLPS(cores_per_cluster_[outer]);
+          // May have been capped by max_threads.
+          HWY_ASSERT(inner.NumWorkers() <= cores.size());
+
+          inner.Run(0, inner.NumWorkers(),
+                    [&cores](uint64_t task, size_t thread) {
+                      HWY_ASSERT(task == thread);  // each inner has one task
+                      hwy::PinThreadToLogicalProcessor(cores[task]);
+                    });
+        });
+  }
+
+  // Spinning reduces the latency of barrier synchronization, but wastes lots
+  // of energy for long waits, so only do it during generation.
+  // This might also be unsafe in virtualized environments because we require
+  // threads to be running on their own core and thus responsive to the
+  // barrier synchronization.
+  void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
+  void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
+
+  // Bitset of cores, one per cluster, or empty if detection failed. Useful
+  // for displaying the topology.
+  const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; }
+
+  hwy::ThreadPool& Outer() { return outer_pool_; }
+  hwy::ThreadPool& Inner(size_t outer) {
+    HWY_ASSERT(outer < Outer().NumWorkers());
+    return *inner_pools_[outer];
+  }
+
+ private:
+  bool have_threading_support_;
+  CoreBitSets cores_per_cluster_;
+  hwy::ThreadPool outer_pool_;
+  // hwy::ThreadPool is unfortunately not marked as movable, so we have to use
+  // unique_ptr.
+  std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
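
Usage sketch: the standalone example below shows how a frontend is expected to
drive the two-level pools added in util/threading.h, with the outer pool
parallelizing across clusters and each inner pool across the cores of one
cluster, similar to how PrefillState::Prefill parallelizes queries over the
outer pool. It is an illustration only, not part of the patch: the input
vector, sharding, and sums are invented for the example; the only assumed APIs
are the PerClusterPools members defined above and hwy::ThreadPool::Run.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <vector>

#include "util/threading.h"

int main() {
  // Zero means no user-imposed limit; the constructor still caps both counts
  // by the detected topology.
  gcpp::PerClusterPools pools(/*max_clusters=*/0, /*max_threads=*/0);

  const std::vector<float> data(1u << 20, 1.0f);
  const size_t num_clusters = pools.Outer().NumWorkers();
  std::vector<double> cluster_sum(num_clusters, 0.0);

  pools.StartSpinning();  // Low-latency barriers while the pools are busy.

  // Outer pool: one task per cluster, each reducing a contiguous shard.
  pools.Outer().Run(0, num_clusters, [&](uint64_t c, size_t outer_thread) {
    const size_t begin = data.size() * c / num_clusters;
    const size_t end = data.size() * (c + 1) / num_clusters;

    // Inner pool: workers pinned to the cores of this cluster, hence they
    // share its last-level cache.
    hwy::ThreadPool& inner = pools.Inner(outer_thread);
    std::vector<double> per_worker(inner.NumWorkers(), 0.0);
    inner.Run(begin, end, [&](uint64_t i, size_t thread) {
      per_worker[thread] += data[i];  // no cross-core writes
    });
    for (double sum : per_worker) cluster_sum[c] += sum;
  });

  pools.StopSpinning();  // Avoid burning cycles while idle.

  double total = 0.0;
  for (double sum : cluster_sum) total += sum;
  printf("total = %.1f (expected %zu)\n", total, data.size());
  return 0;
}

The StartSpinning/StopSpinning bracket around the parallel section follows the
same pattern as Gemma::Generate and Gemma::GenerateBatch in gemma.cc above;
per-worker partial sums stand in for whatever per-thread state a real frontend
would accumulate.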