mirror of https://github.com/google/gemma.cpp.git
1.1x prefill speedup; revamp threading in preparation for hierarchical parallelism.
Limit thread counts to the detected topology. Add a max_clusters argument. Update detection logic to check for smt0; previously we pinned to some SMT siblings.
PiperOrigin-RevId: 659755311
This commit is contained in:
parent 1617e1a33d
commit 5e433e774a
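For orientation before the per-file diff: every frontend in this change swaps its single hwy::ThreadPool for the new PerClusterPools, and code paths that cannot yet exploit hierarchical parallelism fall back to the first inner pool. A minimal sketch of the new setup, distilled from the example-frontend changes below (not part of the diff; error handling and inference setup omitted):

#include "util/app.h"        // LoaderArgs, AppArgs, CreateGemma
#include "util/threading.h"  // PerClusterPools

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::AppArgs app(argc, argv);  // now carries num_threads and max_clusters

  // Both arguments are upper bounds; zero means "no limit". The actual pool
  // sizes are further capped by the detected topology.
  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
  gcpp::Gemma model = gcpp::CreateGemma(loader, pools);

  // Sections without hierarchical parallelism reuse inner pool 0, which
  // matches the previous single-pool behavior.
  hwy::ThreadPool& pool = pools.Inner(0);
  (void)pool;
  (void)model;
  return 0;
}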
BUILD.bazel | 16
@@ -148,6 +148,16 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "threading",
+    hdrs = ["util/threading.h"],
+    deps = [
+        "@hwy//:hwy",
+        "@hwy//:thread_pool",
+        "@hwy//:topology",
+    ],
+)
+
 cc_library(
     name = "gemma_lib",
     srcs = [
@@ -192,6 +202,7 @@ cc_library(
         ":tokenizer",
         ":kv_cache",
         ":weights",
+        ":threading",
        "//compression:io",
         "@hwy//:hwy",
         "@hwy//:bit_set",
@@ -248,6 +259,7 @@ cc_library(
         ":cross_entropy",
         ":gemma_lib",
         ":kv_cache",
+        ":threading",
         # Placeholder for internal dep, do not remove.,
         "@benchmark//:benchmark",
         "//compression:compress",
@@ -286,10 +298,9 @@ cc_binary(
         ":benchmark_helper",
         ":common",
         ":gemma_lib",
+        ":threading",
         # Placeholder for internal dep, do not remove.,
         "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:nanobenchmark",
         "@hwy//:profiler",
         "@hwy//:thread_pool",
     ],
@@ -486,6 +497,7 @@ cc_test(
         ":optimizer",
         ":prompt",
         ":sampler",
+        ":threading",
         ":weights",
         "@googletest//:gtest_main",
         "@hwy//:thread_pool",

backprop/optimize_test.cc
@@ -29,12 +29,14 @@
 #include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
 TEST(OptimizeTest, GradientDescent) {
-  hwy::ThreadPool pool(0);
+  PerClusterPools pools(1, 1);
+  hwy::ThreadPool& pool = pools.Inner(0);
   std::mt19937 gen(42);
 
   const ModelInfo info = {
@@ -54,7 +56,7 @@ TEST(OptimizeTest, GradientDescent) {
       CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
   KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
 
-  Gemma gemma(GemmaTokenizer(), info, pool);
+  Gemma gemma(GemmaTokenizer(), info, pools);
 
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;

evals/benchmark_helper.cc
@@ -25,7 +25,6 @@
 #include <ostream>
 #include <random>
 #include <string>
-#include <thread>  // NOLINT
 #include <utility>  // std::pair
 #include <vector>
 
@@ -37,6 +36,7 @@
 #include "gemma/kv_cache.h"
 #include "util/app.h"
 #include "util/args.h"
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -61,12 +61,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     : loader_(loader),
       inference_args_(inference),
       app_(app),
-      pool_(app_.num_threads) {
-  // For many-core, pinning workers to cores helps.
-  if (app_.num_threads > 10) {
-    PinWorkersToCores(pool_);
-  }
-
+      pools_(app_.max_clusters, app_.num_threads) {
   AbortIfInvalidArgs(inference_args_);
 
   if (const char* err = loader_.Validate()) {
@@ -74,7 +69,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     fprintf(stderr, "Skipping model load because: %s\n", err);
   } else {
     fprintf(stderr, "Loading model...\n");
-    model_ = AllocateGemma(loader_, pool_);
+    model_ = AllocateGemma(loader_, pools_);
 
     // Only allocate one for starters because GenerateBatch might not be called.
     kv_caches_.resize(1);
@@ -234,7 +229,8 @@ void LogSpeedStats(double time_start, size_t total_tokens) {
             << " [" << tok_sec << " tokens / sec" << "]\n";
 }
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
+                PerClusterPools& pools) {
   loader.Print(app.verbosity);
   inference.Print(app.verbosity);
   app.Print(app.verbosity);
@@ -242,22 +238,23 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   if (app.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
-    // TODO: replace hardware_concurrency with detected topology.
-    std::cout << "Date & Time : " << dt
-              << "Hardware concurrency : "
-              << std::thread::hardware_concurrency() << "\n"
-              << "Instruction set : "
-              << hwy::TargetName(hwy::DispatchedTarget()) << " ("
-              << hwy::VectorBytes() * 8 << " bits)" << "\n";
-    char cpu100[100];
-    if (hwy::platform::GetCpuString(cpu100)) {
-      std::cout << "CPU : " << cpu100 << "\n";
-    }
-    std::cout << "Compiled config : " << CompiledConfig() << "\n"
-              << "Weight Type : "
-              << StringFromType(loader.Info().weight) << "\n"
-              << "EmbedderInput Type : "
-              << TypeName(EmbedderInputT()) << "\n";
+    char cpu100[100] = "unknown";
+    (void)hwy::platform::GetCpuString(cpu100);
+
+    fprintf(stderr,
+            "Date & Time : %s"  // dt includes \n
+            "CPU : %s\n"
+            "CPU topology : %zux%zu, using %zux%zu\n"
+            "Instruction set : %s (%zu bits)\n"
+            "Compiled config : %s\n"
+            "Weight Type : %s\n"
+            "EmbedderInput Type : %s\n",
+            dt, cpu100, pools.CoresPerCluster().size(),
+            pools.CoresPerCluster()[0].Count(), pools.Outer().NumWorkers(),
+            pools.Inner(0).NumWorkers(),
+            hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
+            CompiledConfig(), StringFromType(loader.Info().weight),
+            TypeName(EmbedderInputT()));
   }
 }

evals/benchmark_helper.h
@@ -26,8 +26,8 @@
 
 #include "gemma/gemma.h"
 #include "util/app.h"
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
@@ -98,7 +98,7 @@ class GemmaEnv {
   // Controls overall behavior of the app.
   AppArgs app_;
   // Thread pool for running inference.
-  hwy::ThreadPool pool_;
+  PerClusterPools pools_;
   // Random number generator.
   std::mt19937 gen_;
   // The model to run inference on.
@@ -111,7 +111,8 @@
 // Logs the inference speed in tokens/sec.
 void LogSpeedStats(double time_start, size_t total_tokens);
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
+void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
+                PerClusterPools& pools);
 void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
 
 }  // namespace gcpp

examples/hello_world/BUILD.bazel
@@ -15,6 +15,7 @@ cc_binary(
         "//:args",
         "//:common",
         "//:gemma_lib",
+        "//:threading",
         "//:tokenizer",
         "@hwy//:hwy",
         "@hwy//:thread_pool",

examples/hello_world/run.cc
@@ -26,12 +26,14 @@
 #include "gemma/tokenizer.h"
 #include "util/app.h"  // LoaderArgs
 #include "util/args.h"
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 int main(int argc, char** argv) {
   gcpp::LoaderArgs loader(argc, argv);
   gcpp::InferenceArgs inference(argc, argv);
+  gcpp::AppArgs app(argc, argv);
   if (gcpp::HasHelp(argc, argv)) {
     loader.Help();
     return 0;
@@ -41,8 +43,8 @@ int main(int argc, char** argv) {
   }
 
   // Instantiate model and KV Cache
-  hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
+  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
   size_t pos = 0;  // KV Cache position

gemma/gemma-inl.h
@@ -42,11 +42,11 @@
 #include "ops/matmul-inl.h"
 #include "ops/matvec-inl.h"
 #include "ops/ops-inl.h"
+#include "util/threading.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/bit_set.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/contrib/thread_pool/topology.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
@@ -575,143 +575,28 @@ HWY_NOINLINE void TransformerLayer(
 // NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
 // to their num_queries.
 class PrefillState {
-  // TODO: move helper functions, also those in app.h, to a threading header
-  using LPS = hwy::LogicalProcessorSet;
-  LPS Intersection(const LPS& big_set, const LPS& small_set) {
-    LPS both_set;
-    // Reduce expected work by iterating over the smaller set.
-    small_set.Foreach([&big_set, &both_set](size_t idx) {
-      if (big_set.Get(idx)) both_set.Set(idx);
-    });
-    return both_set;
-  }
-
-  std::vector<size_t> CoresInLPS(const LPS& cluster) {
-    std::vector<size_t> cores;
-    cores.reserve(cluster.Count());
-    cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
-    return cores;
-  }
-
-  // For each cluster (shared L3 cache), a bitset of cores.
-  using CoresPerCluster = std::vector<LPS>;
-
-  // Returns empty if detection failed.
-  CoresPerCluster DetectClusters() {
-    CoresPerCluster clusters;
-    // Which processors are not disabled via OS, taskset, or numactl.
-    LPS enabled;
-    // If we don't know, better to use just a single inner pool rather than risk
-    // oversubscribing to enabled cores.
-    if (!GetThreadAffinity(enabled)) return clusters;
-
-    hwy::Topology topology;
-    if (topology.packages.empty()) return clusters;
-
-    // For each cluster = outer, the cores that will be used for an inner pool.
-    CoresPerCluster inner_lps;
-    for (const hwy::Topology::Package& package : topology.packages) {
-      for (const hwy::Topology::Cluster& cluster : package.clusters) {
-        // Only use enabled cores, and only add if not empty.
-        const LPS lps = Intersection(enabled, cluster.lps);
-        if (lps.Any()) clusters.push_back(lps);
-      }
-    }
-
-    // Sort by descending number of enabled cores, so that we preferentially
-    // use the largest clusters.
-    std::sort(clusters.begin(), clusters.end(),
-              [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
-
-    return clusters;
-  }
-
-  // Returns false if the main pool should be reused instead.
-  bool AssignInnerPoolsToClusters(const size_t num_queries) {
-    HWY_ASSERT(num_queries != 0);
-
-    CoresPerCluster inner_lps = DetectClusters();
-    // If we have more outer workers than queries, discard the excess.
-    if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
-    // If we're not going to create multiple pools, avoid the overhead of
-    // re-pinning (60 ms) and reuse the main pool.
-    if (inner_lps.size() <= 1) return false;
-
-    // Before creating new threads, stop the old ones from spinning. Caller is
-    // responsible for undoing this by calling `ResumeMainSpinning`.
-    main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
-
-    outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
-    outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
-
-    HWY_ASSERT(inner_pools_.empty());
-    for (const LPS& inner : inner_lps) {
-      inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
-      inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
-    }
-
-    // For each inner pool, pin their threads AND the associated outer thread
-    // to the enabled cores in the cluster.
-    outer_pool_->Run(
-        0, inner_lps.size(),
-        [this, &inner_lps](uint64_t outer, size_t outer_thread) {
-          HWY_ASSERT(outer == outer_thread);  // each outer has one task
-          const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);
-
-          inner_pools_[outer]->Run(
-              0, cores.size(), [&cores](uint64_t task, size_t thread) {
-                HWY_ASSERT(task == thread);  // each inner has one task
-                hwy::PinThreadToLogicalProcessor(cores[task]);
-              });
-        });
-
-    return true;
-  }
-
-  void ReuseMainPoolAsInner() {
-    // Still allocate an empty pool to simplify Prefill().
-    outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
-
-    HWY_ASSERT(inner_pools_.empty());
-    inner_pools_.push_back(main_pool_);
-  }
-
  public:
-  // Creates pools. AllocateActivations must still be called separately; it has
-  // a template argument.
-  PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
-      : main_pool_(&main_pool) {
-    PROFILER_ZONE("Init.Prefill.Ctor");
-    if (!AssignInnerPoolsToClusters(num_queries)) {
-      ReuseMainPoolAsInner();
-    }
-  }
-
-  ~PrefillState() {
-    for (hwy::ThreadPool* p : inner_pools_) {
-      if (p != main_pool_) delete p;
-    }
-  }
-
   // `tbatch_size` is the number of tokens from one query to prefill at a time.
   template <class TConfig>
-  void AllocateActivations(size_t num_queries, size_t tbatch_size) {
-    PROFILER_ZONE("Init.Prefill.AllocateActivations");
-
-    const size_t outer_workers = outer_pool_->NumWorkers();
-    HWY_ASSERT(outer_workers != 0);  // Otherwise activations_ is empty.
-
+  void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
+    PROFILER_ZONE("Init.Prefill");
+    HWY_ASSERT(num_queries != 0);
     HWY_ASSERT(activations_.empty());  // only call once.
-    activations_.resize(outer_workers);
-
-    if (outer_workers == 1) {
+    // Allocate one activation per query, not outer worker, because the common
+    // case is a single query. If we allocate the lesser of the two, it is
+    // unclear how to choose an unused activation in Prefill.
+    activations_.resize(num_queries);
+
+    if (num_queries == 1) {
       activations_[0].Allocate<TConfig>(tbatch_size);
     } else {
-      // Allocating in parallel can save 30 ms.
-      main_pool_->Run(0, outer_workers,
-                      [this, tbatch_size](uint64_t task, size_t /*thread*/) {
-                        activations_[task].Allocate<TConfig>(tbatch_size);
-                      });
+      // Allocating in parallel can save 30 ms. We might have more workers than
+      // queries/tasks, so do not check the `thread` argument.
+      pools.Outer().Run(0, num_queries,
+                        [this, tbatch_size](uint64_t qi, size_t /*thread*/) {
+                          activations_[qi].Allocate<TConfig>(tbatch_size);
+                        });
     }
   }
@@ -721,7 +606,7 @@
                const size_t query_idx_start,
                const CompressedWeights<TConfig>& weights,
                const RuntimeConfig& runtime_config,
-               const KVCaches& kv_caches) {
+               const KVCaches& kv_caches, PerClusterPools& pools) {
     PROFILER_ZONE("Gen.Prefill");
     const size_t num_queries = prompts.size();
     HWY_ASSERT(kv_caches.size() == num_queries);
@@ -729,10 +614,10 @@
 
     // For each query (parallel): an outer worker processes all its tokens.
     // `qi` is relative to the batch, not the global query index.
-    outer_pool_->Run(
+    pools.Outer().Run(
         0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
-          Activations& activations = activations_[qthread];
-          hwy::ThreadPool& inner_pool = *inner_pools_[qthread];
+          Activations& activations = activations_[qi];
+          hwy::ThreadPool& inner_pool = pools.Inner(qthread);
 
           // Single query at a time, so pass a slice of the KV cache because
           // GemmaAttention will only access the first.
@@ -768,29 +653,8 @@ class PrefillState {
         });
   }
 
-  // Stops spinning in our pools and resume spinning in main_pool_.
-  void ResumeMainSpinning() {
-    // If we didn't create a new inner pool, we didn't stop spinning on the
-    // main pool, so nothing to do here.
-    if (inner_pools_[0] == main_pool_) return;
-
-    for (hwy::ThreadPool* p : inner_pools_) {
-      p->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    }
-    outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
-    main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
-  }
-
  private:
-  hwy::ThreadPool* main_pool_;
-  std::unique_ptr<hwy::ThreadPool> outer_pool_;  // always allocated
-  // Holds a single pointer equal to main_pool_, or new allocations; in either
-  // case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
-  // allocation overhead for the common case of a single query.
-  std::vector<hwy::ThreadPool*> inner_pools_;
-
-  // size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
-  std::vector<Activations> activations_;
+  std::vector<Activations> activations_;  // One per query, filled by Init.
 };
 
 // `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@@ -945,12 +809,15 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
                const RuntimeConfig& runtime_config,
                const MultiplePromptsTokens& prompts, const size_t pos,
                const size_t query_idx_start, const KVCaches& kv_caches,
-               hwy::ThreadPool& pool, TimingInfo& timing_info) {
+               PerClusterPools& pools, TimingInfo& timing_info) {
   constexpr size_t kModelDim = TConfig::kModelDim;
   constexpr size_t kVocabSize = TConfig::kVocabSize;
   const CompressedWeights<TConfig>& weights =
       *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
 
+  // TODO: remove once all parallel sections support hierarchical parallelism.
+  hwy::ThreadPool& pool = pools.Inner(0);
+
   const size_t num_queries = prompts.size();
   HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
   HWY_ASSERT(num_queries <= activations.x.BatchSize());
@@ -984,14 +851,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   const size_t prefill_per_query = min_prompt_size - 1;
   double prefill_start;
   {
-    PrefillState prefill(pool, num_queries);
-    prefill.AllocateActivations<TConfig>(num_queries,
-                                         runtime_config.prefill_tbatch_size);
+    // TODO: move to Gemma, reuse across calls to Generate.
+    PrefillState prefill;
+    prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
+                          pools);
     prefill_start = hwy::platform::Now();
     prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
-                             weights, runtime_config, kv_caches);
+                             weights, runtime_config, kv_caches, pools);
     timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
-    prefill.ResumeMainSpinning();
   }
 
   size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
@@ -1051,7 +918,7 @@ template <class TConfig>
 void GenerateSingleT(const ByteStorageT& weights_u8,
                      const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
-                     hwy::ThreadPool& pool, TimingInfo& timing_info) {
+                     PerClusterPools& pools, TimingInfo& timing_info) {
   const size_t num_queries = 1;
   const size_t qbatch_start = 0;
 
@@ -1062,14 +929,14 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
   const KVCaches kv_caches{&kv_cache, num_queries};
 
   GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
-                     qbatch_start, kv_caches, pool, timing_info);
+                     qbatch_start, kv_caches, pools, timing_info);
 }
 
 template <class TConfig>
 void GenerateBatchT(const ByteStorageT& weights_u8,
                     const RuntimeConfig& runtime_config,
                     const MultiplePromptsTokens& prompts, size_t pos,
-                    const KVCaches& kv_caches, hwy::ThreadPool& pool,
+                    const KVCaches& kv_caches, PerClusterPools& pools,
                     TimingInfo& timing_info) {
   HWY_ASSERT(prompts.size() == kv_caches.size());
   // Griffin does not support query batching.
@@ -1089,7 +956,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
                                               qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
     GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
-                       pos, qbatch_start, qbatch_kv, pool, timing_info);
+                       pos, qbatch_start, qbatch_kv, pools, timing_info);
   }
 }
 
@@ -1102,18 +969,18 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
     GEMMA_CONFIG, const ByteStorageT& weights_u8,
     const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
-    KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
+    KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
-  (weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
+  (weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
 }
 
 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
     GEMMA_CONFIG, const ByteStorageT& weights_u8,
     const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
-    size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
+    size_t pos, const KVCaches& kv_caches, PerClusterPools& pools,
    TimingInfo& timing_info) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
-  (weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
+  (weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
 }
 
 #endif  // HWY_ONCE

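The restructured PrefillState above reduces to a two-level pattern: the outer pool parallelizes across queries, and each outer worker drives the inner pool of one cluster for its query's tokens. A stand-alone sketch of that shape (not from the diff; DoTokenBatch is a hypothetical stand-in for the per-token transformer work):

#include <stddef.h>
#include <stdint.h>

#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical placeholder for per-token work (TransformerLayer etc.),
// parallelized across the cores of a single cluster (shared L3 cache).
void DoTokenBatch(uint64_t query, size_t token, hwy::ThreadPool& inner) {
  (void)query;
  (void)token;
  inner.Run(0, inner.NumWorkers(), [](uint64_t /*task*/, size_t /*thread*/) {
    // ... compute for (query, token) on one core of this cluster ...
  });
}

void PrefillAll(gcpp::PerClusterPools& pools, size_t num_queries,
                size_t tokens_per_query) {
  // Outer: one worker per cluster; each processes all tokens of its query.
  pools.Outer().Run(0, num_queries, [&](uint64_t qi, size_t qthread) {
    hwy::ThreadPool& inner = pools.Inner(qthread);
    for (size_t t = 0; t < tokens_per_query; ++t) {
      DoTokenBatch(qi, t, inner);
    }
  });
}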
gemma/gemma.cc
@@ -28,23 +28,25 @@
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/weights.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
 
 namespace gcpp {
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
-             const ModelInfo& info, hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
-  weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
+             const ModelInfo& info, PerClusterPools& pools)
+    : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
+  weights_u8_ =
+      LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
-             hwy::ThreadPool& pool)
-    : pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
+             PerClusterPools& pools)
+    : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
   HWY_ASSERT(info.weight == Type::kF32);
-  weights_u8_ =
-      CallForModel<float, AllocateCompressedWeights>(info.model, pool);
+  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(info.model,
+                                                               pools_.Inner(0));
 }
 
 Gemma::~Gemma() {
@@ -60,12 +62,12 @@ Gemma::~Gemma() {
   extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
                              const RuntimeConfig& runtime_config, \
                              const PromptTokens& prompt, size_t pos, \
-                             KVCache& kv_cache, hwy::ThreadPool& pool, \
+                             KVCache& kv_cache, PerClusterPools& pools, \
                              TimingInfo& timing_info); \
   extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
                             const RuntimeConfig& runtime_config, \
                             const MultiplePromptsTokens& prompts, size_t pos, \
-                            const KVCaches& kv_caches, hwy::ThreadPool& pool, \
+                            const KVCaches& kv_caches, PerClusterPools& pools, \
                             TimingInfo& timing_info);
 GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
 
@@ -75,9 +77,9 @@ struct GenerateSingleT {
   void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
                   const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
-                  hwy::ThreadPool& pool, TimingInfo& timing_info) const {
+                  PerClusterPools& pools, TimingInfo& timing_info) const {
     GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
-                   pool, timing_info);
+                   pools, timing_info);
   }
 };
 
@@ -86,36 +88,36 @@ struct GenerateBatchT {
   void operator()(const ByteStorageT& weights_u8,
                   const RuntimeConfig& runtime_config,
                   const MultiplePromptsTokens& prompts, size_t pos,
-                  const KVCaches& kv_caches, hwy::ThreadPool& pool,
+                  const KVCaches& kv_caches, PerClusterPools& pools,
                   TimingInfo& timing_info) const {
     GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
-                  kv_caches, pool, timing_info);
+                  kv_caches, pools, timing_info);
   }
 };
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t start_pos,
                      KVCache& kv_cache, TimingInfo& timing_info) {
-  pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
+  pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
                                          runtime_config, prompt, start_pos,
-                                         kv_cache, pool_, timing_info);
+                                         kv_cache, pools_, timing_info);
 
-  pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
+  pools_.StopSpinning();
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
                           const MultiplePromptsTokens& prompts,
                           size_t start_pos, const KVCaches& kv_caches,
                           TimingInfo& timing_info) {
-  pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
+  pools_.StartSpinning();
 
   CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
                                         runtime_config, prompts, start_pos,
-                                        kv_caches, pool_, timing_info);
+                                        kv_caches, pools_, timing_info);
 
-  pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
+  pools_.StopSpinning();
 }
 
 }  // namespace gcpp

gemma/gemma.h
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
+#include "util/threading.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 // IWYU pragma: end_exports
@@ -56,8 +57,8 @@ using SampleFunc = std::function<int(const float*, size_t)>;
 // - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
 // - pointer to the data array
 // - size of the data array
-using LayersOutputFunc =
-    std::function<void(size_t, size_t, const std::string&, int, const float*, size_t)>;
+using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
+                                            int, const float*, size_t)>;
 
 struct RuntimeConfig {
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
@@ -121,11 +122,11 @@ using KVCaches = hwy::Span<KVCache>;
 class Gemma {
  public:
   Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
-        hwy::ThreadPool& pool);
+        PerClusterPools& pools);
 
   // Allocates weights, caller is responsible for filling them.
   Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
-        hwy::ThreadPool& pool);
+        PerClusterPools& pools);
   ~Gemma();
 
   const ModelInfo& Info() const { return info_; }
@@ -140,7 +141,7 @@ class Gemma {
                      const KVCaches& kv_caches, TimingInfo& timing_info);
 
  private:
-  hwy::ThreadPool& pool_;
+  PerClusterPools& pools_;
 
   GemmaTokenizer tokenizer_;
   // Type-erased so that this can be defined in the header.
@@ -155,14 +156,6 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  const ModelInfo& info, size_t pos,
                                  std::string& prompt);
 
-// DEPRECATED, call Gemma::Generate directly.
-HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
-                              const std::vector<int>& prompt, size_t start_pos,
-                              KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
-                              TimingInfo& timing_info) {
-  gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
-}
-
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

gemma/run.cc | 21
@@ -27,6 +27,7 @@
 #include "gemma/gemma.h"  // Gemma
 #include "util/app.h"
 #include "util/args.h"  // HasHelp
+#include "util/threading.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -75,9 +76,9 @@ std::string GetPrompt(std::istream& input, int verbosity,
 }
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
-               const InferenceArgs& args, int verbosity,
-               const AcceptFunc& accept_token, std::string& eot_line) {
+void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
+               int verbosity, const AcceptFunc& accept_token,
+               std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;  // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -172,13 +173,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
 void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   PROFILER_ZONE("Run.misc");
 
-  hwy::ThreadPool pool(app.num_threads);
-  // For many-core, pinning workers to cores helps.
-  if (app.num_threads > 10) {
-    PinWorkersToCores(pool);
-  }
+  // Note that num_threads is an upper bound; we also limit to the number of
+  // detected and enabled cores.
+  PerClusterPools pools(app.max_clusters, app.num_threads);
 
-  Gemma model = CreateGemma(loader, pool);
+  Gemma model = CreateGemma(loader, pools);
   KVCache kv_cache =
       KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
 
@@ -205,11 +204,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
 
     std::cout << "\033[2J\033[1;1H"  // clear screen
              << kAsciiArtBanner << "\n\n";
-    ShowConfig(loader, inference, app);
+    ShowConfig(loader, inference, app, pools);
     std::cout << "\n" << instructions << "\n";
   }
 
-  ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
+  ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
             app.eot_line);
 }

util/app.h | 107
@@ -23,20 +23,12 @@
 
 #include <memory>
 #include <string>
 #include <vector>
 
 #include "compression/io.h"  // Path
 #include "gemma/common.h"
 #include "gemma/configs.h"
-#include "gemma/gemma.h"
+#include "gemma/gemma.h"  // For CreateGemma
 #include "util/args.h"
-#include "hwy/base.h"  // HWY_ASSERT
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/contrib/thread_pool/topology.h"
-
-#if HWY_OS_LINUX
-#include <sched.h>
-#endif  // HWY_OS_LINUX
+#include "hwy/base.h"  // HWY_IS_ASAN
 
 namespace gcpp {
 
@@ -58,86 +50,14 @@ static inline const char* CompiledConfig() {
   }
 }
 
-static inline std::vector<size_t> LpsToCpus(
-    const hwy::LogicalProcessorSet& lps) {
-  std::vector<size_t> cpus;
-  cpus.reserve(lps.Count());
-  lps.Foreach([&cpus](size_t lp) { cpus.push_back(lp); });
-  return cpus;
-}
-
-static inline std::vector<size_t> AssignCpusFromTopology(
-    const hwy::Topology& topology, const size_t num_workers) {
-  // Assign CPUs to workers 0 to num_workers - 1 based on the topology.
-  // The assignments are done in a round-robin fashion across all clusters and
-  // cores.
-  // For example, if we have 4 clusters, the assignments will be:
-  //   Thread 0 -> Cluster 0, Core 0
-  //   Thread 1 -> Cluster 1, Core 0
-  //   Thread 2 -> Cluster 2, Core 0
-  //   Thread 3 -> Cluster 3, Core 0
-  //   Thread 4 -> Cluster 0, Core 1
-  //   Thread 5 -> Cluster 1, Core 1
-  //   ... and so on.
-  //
-  // This would result in the least amount of sharing of the last-level
-  // cache slices. All assignments are made from Package 0.
-  std::vector<std::vector<size_t>> clusters;
-  for (auto& package : topology.packages) {
-    for (auto& cluster : package.clusters) {
-      clusters.push_back(LpsToCpus(cluster.lps));
-    }
-  }
-  std::vector<size_t> assigned_cpus;
-  assigned_cpus.reserve(num_workers);
-  for (size_t i = 0; i < num_workers; ++i) {
-    size_t cluster_index = i % clusters.size();
-    size_t cpu_index = (i / clusters.size()) % clusters[cluster_index].size();
-    assigned_cpus.push_back(clusters[cluster_index][cpu_index]);
-  }
-  return assigned_cpus;
-}
-
-static inline void PinWorkersToCores(hwy::ThreadPool& pool) {
-  // Use topology to pin workers to cores if available.
-  hwy::Topology topology;
-  if (!topology.packages.empty()) {
-    std::vector<size_t> assigned_cpus =
-        AssignCpusFromTopology(topology, pool.NumWorkers());
-    pool.Run(0, pool.NumWorkers(),
-             [&assigned_cpus](uint64_t /*task*/, size_t thread) {
-               hwy::PinThreadToLogicalProcessor(assigned_cpus[thread]);
-             });
-  } else {
-    pool.Run(0, pool.NumWorkers(), [](uint64_t /*task*/, size_t thread) {
-      hwy::PinThreadToLogicalProcessor(thread);
-    });
-  }
-}
-
 class AppArgs : public ArgsBase<AppArgs> {
-  static constexpr size_t kDefaultNumThreads = ~size_t{0};
-
-  void ChooseNumThreads() {
-    if (num_threads == kDefaultNumThreads) {
-      // This is a rough heuristic, replace with something better in the future.
-      num_threads = GetSupportedThreadCount();
-    }
-  }
-
  public:
-  AppArgs(int argc, char* argv[]) {
-    InitAndParse(argc, argv);
-    ChooseNumThreads();
-  }
-
-  static inline size_t GetSupportedThreadCount() {
-    return HWY_MIN(hwy::ThreadPool::MaxThreads(), kMaxThreads);
-  }
+  AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
   Path log;  // output
   int verbosity;
-  size_t num_threads;
+  size_t num_threads;  // divided among the detected clusters
+  size_t max_clusters;
   std::string eot_line;
 
   template <class Visitor>
@@ -147,11 +67,10 @@ class AppArgs : public ArgsBase<AppArgs> {
             "output\n 1 = standard user-facing terminal ui\n 2 = show "
             "developer/debug info).\n Default = 1.",
             2);
-    visitor(num_threads, "num_threads",
-            kDefaultNumThreads,  // see ChooseNumThreads
-            "Number of threads to use.\n Default = Estimate of the "
-            "number of supported concurrent threads.",
-            2);
+    visitor(num_threads, "num_threads", size_t{0},
+            "Maximum number of threads to use; default 0 = unlimited.", 2);
+    visitor(max_clusters, "max_clusters", size_t{0},
+            "Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
     visitor(
         eot_line, "eot_line", std::string(""),
         "End of turn line. "
@@ -232,14 +151,14 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
 };
 
 static inline Gemma CreateGemma(const LoaderArgs& loader,
-                                hwy::ThreadPool& pool) {
-  return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
+                                PerClusterPools& pools) {
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
 }
 
 static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
-                                                   hwy::ThreadPool& pool) {
+                                                   PerClusterPools& pools) {
   return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
-                                 loader.Info(), pool);
+                                 loader.Info(), pools);
 }
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {

util/threading.h (new file)
@@ -0,0 +1,201 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Shared between various frontends.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
+#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <algorithm>  // std::sort
+#include <memory>
+#include <vector>
+
+#include "hwy/base.h"  // HWY_ASSERT
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/contrib/thread_pool/topology.h"
+
+namespace gcpp {
+
+// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
+// 'outer' thread pool with one worker per cluster.
+//
+// Useful for hierarchical parallelism, which makes sense when there are few
+// but large tasks which should be parallelized by workers sharing a cache.
+// This also implies lower latency for barrier synchronization of those workers.
+class PerClusterPools {
+  using LPS = hwy::LogicalProcessorSet;
+
+  static inline std::vector<size_t> CoresInLPS(const LPS& cluster) {
+    std::vector<size_t> cores;
+    cores.reserve(cluster.Count());
+    cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
+    return cores;
+  }
+
+  using CoreBitSets = std::vector<LPS>;
+
+  // Returns empty if detection failed.
+  CoreBitSets DetectCoresPerCluster() {
+    CoreBitSets clusters;
+    if (!have_threading_support_) return clusters;
+
+    // Which processors are not disabled via OS, taskset, or numactl.
+    LPS enabled;
+    // If we don't know, better to abort rather than risk oversubscribing.
+    if (!GetThreadAffinity(enabled)) return clusters;
+
+    hwy::Topology topology;
+    if (topology.packages.empty()) return clusters;
+
+    // Merge all clusters into one set, as a stopgap to emulate gemma-inl's
+    // prior single pool.
+    // TODO: remove once MatMul supports hierarchical parallelism.
+    LPS all;
+
+    // For each cluster, add its enabled *cores*.
+    for (const hwy::Topology::Package& package : topology.packages) {
+      for (const hwy::Topology::Cluster& cluster : package.clusters) {
+        cluster.lps.Foreach([&](size_t lp) {
+          if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
+            all.Set(lp);
+          }
+        });
+      }
+
+      /* code to reinstate:
+      for (const hwy::Topology::Cluster& cluster : package.clusters) {
+        // Only use enabled *cores*, and only add if not empty.
+        cluster.lps.Foreach([&](size_t lp) {
+          if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
+            all.Set(lp);
+          }
+        });
+        if (lps.Any()) clusters.push_back(lps);
+      }
+      */
+    }
+    if (all.Any()) clusters.push_back(all);
+
+    // Sort by descending number of enabled cores, so that we preferentially
+    // use the largest clusters.
+    std::sort(clusters.begin(), clusters.end(),
+              [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
+
+    return clusters;
+  }
+
+  void SetWaitMode(hwy::PoolWaitMode wait_mode) {
+    outer_pool_.SetWaitMode(wait_mode);
+    for (auto& inner : inner_pools_) {
+      inner->SetWaitMode(wait_mode);
+    }
+  }
+
+  // The defaults for `AppArgs` `max_clusters` and `num_threads` are zero,
+  // which means no limit.
+  size_t CapIfNonzero(size_t detected, size_t user_max_or_zero) {
+    if (!have_threading_support_) return 0;
+    return (user_max_or_zero == 0) ? detected
+                                   : HWY_MIN(detected, user_max_or_zero);
+  }
+
+ public:
+  // PerClusterPools supports spin waits (see StartSpinning below). To prevent
+  // drastic slowdowns caused by excessive user-specified thread counts, which
+  // result in threads not running on their own core, we only allow for
+  // *upper bounds* on the number of clusters and threads. The actual number of
+  // clusters and threads are still limited by the detected topology.
+  PerClusterPools(size_t max_clusters, size_t max_threads)
+      : have_threading_support_(hwy::HaveThreadingSupport()),
+        cores_per_cluster_(DetectCoresPerCluster()),
+        outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
+    // Topology detection failed - it currently requires Linux.
+    if (cores_per_cluster_.empty()) {
+      // Create a single inner pool with up to TotalLogicalProcessors() / 2
+      // workers, further limited by `max_threads` if nonzero, and then pin to
+      // the first N processors, which are typically on the first socket.
+      const size_t num_threads =
+          CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
+      fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
+      inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
+      if (num_threads > 1) {
+        inner_pools_.back()->Run(0, num_threads,
+                                 [](uint64_t /*task*/, size_t thread) {
+                                   hwy::PinThreadToLogicalProcessor(thread);
+                                 });
+      }
+      return;
+    }
+
+    const size_t max_per_inner = max_threads / outer_pool_.NumWorkers();
+    for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
+      const size_t num_threads =
+          CapIfNonzero(cores_per_cluster_[outer].Count(), max_per_inner);
+      inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
+    }
+
+    // For each inner pool, pin their threads AND the associated outer thread
+    // (the one calling inner.Run()) to the enabled cores in the cluster.
+    outer_pool_.Run(
+        0, outer_pool_.NumWorkers(),
+        [this](uint64_t outer, size_t outer_thread) {
+          HWY_ASSERT(outer == outer_thread);  // each outer has one task
+          hwy::ThreadPool& inner = *inner_pools_[outer];
+
+          const std::vector<size_t> cores =
+              CoresInLPS(cores_per_cluster_[outer]);
+          // May have been capped by max_threads.
+          HWY_ASSERT(inner.NumWorkers() <= cores.size());
+
+          inner.Run(0, inner.NumWorkers(),
+                    [&cores](uint64_t task, size_t thread) {
+                      HWY_ASSERT(task == thread);  // each inner has one task
+                      hwy::PinThreadToLogicalProcessor(cores[task]);
+                    });
+        });
+  }
+
+  // Spinning reduces the latency of barrier synchronization, but wastes lots
+  // of energy for long waits, so only do it during generation. This might also
+  // be unsafe in virtualized environments because we require threads to be
+  // running on their own core and thus responsive to the barrier
+  // synchronization.
+  void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
+  void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
+
+  // Bitset of cores, one per cluster, or empty if detection failed. Useful for
+  // displaying the topology.
+  const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; }
+
+  hwy::ThreadPool& Outer() { return outer_pool_; }
+  hwy::ThreadPool& Inner(size_t outer) {
+    HWY_ASSERT(outer < Outer().NumWorkers());
+    return *inner_pools_[outer];
+  }
+
+ private:
+  bool have_threading_support_;
+  CoreBitSets cores_per_cluster_;
+  hwy::ThreadPool outer_pool_;
+  // hwy::ThreadPool is unfortunately not marked as movable, so we have to use
+  // unique_ptr.
+  std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
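As a usage reference for the header above, a minimal frontend-style sketch (not from the diff; the zero arguments mean "no user-imposed limit", and all counts are illustrative):

#include <stdint.h>
#include <stdio.h>

#include "util/threading.h"

int main() {
  // Upper bounds only; actual counts are limited by the detected topology.
  gcpp::PerClusterPools pools(/*max_clusters=*/0, /*max_threads=*/0);

  fprintf(stderr, "clusters: %zu, workers in cluster 0: %zu\n",
          pools.Outer().NumWorkers(), pools.Inner(0).NumWorkers());

  pools.StartSpinning();  // low-latency barriers, e.g. during generation
  pools.Outer().Run(0, pools.Outer().NumWorkers(),
                    [&pools](uint64_t /*outer*/, size_t outer_thread) {
                      hwy::ThreadPool& inner = pools.Inner(outer_thread);
                      inner.Run(0, inner.NumWorkers(),
                                [](uint64_t /*task*/, size_t /*thread*/) {
                                  // per-core work within one cluster
                                });
                    });
  pools.StopSpinning();  // block while idle to save energy
  return 0;
}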