1.1x prefill speedup, revamp threading in preparation for hierarchical parallelism.

Limit thread counts to the number of detected cores. Add max_clusters arg.
Update detection logic to check for smt == 0; previously some workers were pinned to SMT siblings.
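
Callers now construct a PerClusterPools from the two flags and pass it wherever a
hwy::ThreadPool was expected before. A minimal caller-side sketch, mirroring the
updated run.cc and gemma-inl.h below (both counts are upper bounds, further limited
by the detected topology):

  gcpp::AppArgs app(argc, argv);
  // 0 means "no limit"; actual counts come from the detected clusters/cores.
  gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
  gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
  // Stopgap for sections that do not yet use hierarchical parallelism:
  hwy::ThreadPool& pool = pools.Inner(0);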

PiperOrigin-RevId: 659755311
Authored by Jan Wassenberg on 2024-08-05 18:49:16 -07:00; committed by Copybara-Service
parent 1617e1a33d
commit 5e433e774a
12 changed files with 338 additions and 342 deletions

View File

@ -148,6 +148,16 @@ cc_library(
],
)
cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
],
)
cc_library(
name = "gemma_lib",
srcs = [
@ -192,6 +202,7 @@ cc_library(
":tokenizer",
":kv_cache",
":weights",
":threading",
"//compression:io",
"@hwy//:hwy",
"@hwy//:bit_set",
@ -248,6 +259,7 @@ cc_library(
":cross_entropy",
":gemma_lib",
":kv_cache",
":threading",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
"//compression:compress",
@ -286,10 +298,9 @@ cc_binary(
":benchmark_helper",
":common",
":gemma_lib",
":threading",
# Placeholder for internal dep, do not remove.,
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
],
@ -486,6 +497,7 @@ cc_test(
":optimizer",
":prompt",
":sampler",
":threading",
":weights",
"@googletest//:gtest_main",
"@hwy//:thread_pool",

View File

@ -29,12 +29,14 @@
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
hwy::ThreadPool pool(0);
PerClusterPools pools(1, 1);
hwy::ThreadPool& pool = pools.Inner(0);
std::mt19937 gen(42);
const ModelInfo info = {
@ -54,7 +56,7 @@ TEST(OptimizeTest, GradientDescent) {
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
Gemma gemma(GemmaTokenizer(), info, pool);
Gemma gemma(GemmaTokenizer(), info, pools);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;

View File

@ -25,7 +25,6 @@
#include <ostream>
#include <random>
#include <string>
#include <thread> // NOLINT
#include <utility> // std::pair
#include <vector>
@ -37,6 +36,7 @@
#include "gemma/kv_cache.h"
#include "util/app.h"
#include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -61,12 +61,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
: loader_(loader),
inference_args_(inference),
app_(app),
pool_(app_.num_threads) {
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
PinWorkersToCores(pool_);
}
pools_(app_.max_clusters, app_.num_threads) {
AbortIfInvalidArgs(inference_args_);
if (const char* err = loader_.Validate()) {
@ -74,7 +69,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(loader_, pool_);
model_ = AllocateGemma(loader_, pools_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
@ -234,7 +229,8 @@ void LogSpeedStats(double time_start, size_t total_tokens) {
<< " [" << tok_sec << " tokens / sec" << "]\n";
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
@ -242,22 +238,23 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
if (app.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
// TODO: replace hardware_concurrency with detected topology.
std::cout << "Date & Time : " << dt
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
<< hwy::TargetName(hwy::DispatchedTarget()) << " ("
<< hwy::VectorBytes() * 8 << " bits)" << "\n";
char cpu100[100];
if (hwy::platform::GetCpuString(cpu100)) {
std::cout << "CPU : " << cpu100 << "\n";
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< StringFromType(loader.Info().weight) << "\n"
<< "EmbedderInput Type : "
<< TypeName(EmbedderInputT()) << "\n";
char cpu100[100] = "unknown";
(void)hwy::platform::GetCpuString(cpu100);
fprintf(stderr,
"Date & Time : %s" // dt includes \n
"CPU : %s\n"
"CPU topology : %zux%zu, using %zux%zu\n"
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n"
"Weight Type : %s\n"
"EmbedderInput Type : %s\n",
dt, cpu100, pools.CoresPerCluster().size(),
pools.CoresPerCluster()[0].Count(), pools.Outer().NumWorkers(),
pools.Inner(0).NumWorkers(),
hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
CompiledConfig(), StringFromType(loader.Info().weight),
TypeName(EmbedderInputT()));
}
}
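
The new "CPU topology" line reports detected clusters x cores per (largest) cluster,
followed by the outer x inner worker counts actually in use after applying the
max_clusters/num_threads caps. For example, with 16 enabled cores, which the current
stopgap in DetectCoresPerCluster merges into a single cluster, it would read
something like:

  CPU topology          : 1x16, using 1x16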

View File

@ -26,8 +26,8 @@
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -98,7 +98,7 @@ class GemmaEnv {
// Controls overall behavior of the app.
AppArgs app_;
// Thread pool for running inference.
hwy::ThreadPool pool_;
PerClusterPools pools_;
// Random number generator.
std::mt19937 gen_;
// The model to run inference on.
@ -111,7 +111,8 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
} // namespace gcpp

View File

@ -15,6 +15,7 @@ cc_binary(
"//:args",
"//:common",
"//:gemma_lib",
"//:threading",
"//:tokenizer",
"@hwy//:hwy",
"@hwy//:thread_pool",

View File

@ -26,12 +26,14 @@
#include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs
#include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
return 0;
@ -41,8 +43,8 @@ int main(int argc, char** argv) {
}
// Instantiate model and KV Cache
hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
size_t pos = 0; // KV Cache position

View File

@ -42,11 +42,11 @@
#include "ops/matmul-inl.h"
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/bit_set.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@ -575,142 +575,27 @@ HWY_NOINLINE void TransformerLayer(
// NOT reused across calls to GenerateSingle/GenerateBatch so that we can adapt
// to their num_queries.
class PrefillState {
// TODO: move helper functions, also those in app.h, to a threading header
using LPS = hwy::LogicalProcessorSet;
LPS Intersection(const LPS& big_set, const LPS& small_set) {
LPS both_set;
// Reduce expected work by iterating over the smaller set.
small_set.Foreach([&big_set, &both_set](size_t idx) {
if (big_set.Get(idx)) both_set.Set(idx);
});
return both_set;
}
std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
// For each cluster (shared L3 cache), a bitset of cores.
using CoresPerCluster = std::vector<LPS>;
// Returns empty if detection failed.
CoresPerCluster DetectClusters() {
CoresPerCluster clusters;
// Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, better to use just a single inner pool rather than risk
// oversubscribing to enabled cores.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// For each cluster = outer, the cores that will be used for an inner pool.
CoresPerCluster inner_lps;
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
// Only use enabled cores, and only add if not empty.
const LPS lps = Intersection(enabled, cluster.lps);
if (lps.Any()) clusters.push_back(lps);
}
}
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
// Returns false if the main pool should be reused instead.
bool AssignInnerPoolsToClusters(const size_t num_queries) {
HWY_ASSERT(num_queries != 0);
CoresPerCluster inner_lps = DetectClusters();
// If we have more outer workers than queries, discard the excess.
if (inner_lps.size() > num_queries) inner_lps.resize(num_queries);
// If we're not going to create multiple pools, avoid the overhead of
// re-pinning (60 ms) and reuse the main pool.
if (inner_lps.size() <= 1) return false;
// Before creating new threads, stop the old ones from spinning. Caller is
// responsible for undoing this by calling `ResumeMainSpinning`.
main_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
outer_pool_ = std::make_unique<hwy::ThreadPool>(inner_lps.size());
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
HWY_ASSERT(inner_pools_.empty());
for (const LPS& inner : inner_lps) {
inner_pools_.push_back(new hwy::ThreadPool(inner.Count()));
inner_pools_.back()->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
// For each inner pool, pin their threads AND the associated outer thread
// to the enabled cores in the cluster.
outer_pool_->Run(
0, inner_lps.size(),
[this, &inner_lps](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
const std::vector<size_t> cores = CoresInLPS(inner_lps[outer]);
inner_pools_[outer]->Run(
0, cores.size(), [&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
return true;
}
void ReuseMainPoolAsInner() {
// Still allocate an empty pool to simplify Prefill().
outer_pool_ = std::make_unique<hwy::ThreadPool>(1);
HWY_ASSERT(inner_pools_.empty());
inner_pools_.push_back(main_pool_);
}
public:
// Creates pools. AllocateActivations must still be called separately; it has
// a template argument.
PrefillState(hwy::ThreadPool& main_pool, size_t num_queries)
: main_pool_(&main_pool) {
PROFILER_ZONE("Init.Prefill.Ctor");
if (!AssignInnerPoolsToClusters(num_queries)) {
ReuseMainPoolAsInner();
}
}
~PrefillState() {
for (hwy::ThreadPool* p : inner_pools_) {
if (p != main_pool_) delete p;
}
}
// `tbatch_size` is the number of tokens from one query to prefill at a time.
template <class TConfig>
void AllocateActivations(size_t num_queries, size_t tbatch_size) {
PROFILER_ZONE("Init.Prefill.AllocateActivations");
const size_t outer_workers = outer_pool_->NumWorkers();
HWY_ASSERT(outer_workers != 0); // Otherwise activations_ is empty.
void Init(size_t num_queries, size_t tbatch_size, PerClusterPools& pools) {
PROFILER_ZONE("Init.Prefill");
HWY_ASSERT(num_queries != 0);
HWY_ASSERT(activations_.empty()); // only call once.
activations_.resize(outer_workers);
if (outer_workers == 1) {
// Allocate one activation per query, not outer worker, because the common
// case is a single query. If we allocate the lesser of the two, it is
// unclear how to choose an unused activation in Prefill.
activations_.resize(num_queries);
if (num_queries == 1) {
activations_[0].Allocate<TConfig>(tbatch_size);
} else {
// Allocating in parallel can save 30 ms.
main_pool_->Run(0, outer_workers,
[this, tbatch_size](uint64_t task, size_t /*thread*/) {
activations_[task].Allocate<TConfig>(tbatch_size);
// Allocating in parallel can save 30 ms. We might have more workers than
// queries/tasks, so do not check the `thread` argument.
pools.Outer().Run(0, num_queries,
[this, tbatch_size](uint64_t qi, size_t /*thread*/) {
activations_[qi].Allocate<TConfig>(tbatch_size);
});
}
}
@ -721,7 +606,7 @@ class PrefillState {
const size_t query_idx_start,
const CompressedWeights<TConfig>& weights,
const RuntimeConfig& runtime_config,
const KVCaches& kv_caches) {
const KVCaches& kv_caches, PerClusterPools& pools) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = prompts.size();
HWY_ASSERT(kv_caches.size() == num_queries);
@ -729,10 +614,10 @@ class PrefillState {
// For each query (parallel): an outer worker processes all its tokens.
// `qi` is relative to the batch, not the global query index.
outer_pool_->Run(
pools.Outer().Run(
0, num_queries, [&](const uint64_t qi, size_t qthread) HWY_ATTR {
Activations& activations = activations_[qthread];
hwy::ThreadPool& inner_pool = *inner_pools_[qthread];
Activations& activations = activations_[qi];
hwy::ThreadPool& inner_pool = pools.Inner(qthread);
// Single query at a time, so pass a slice of the KV cache because
// GemmaAttention will only access the first.
@ -768,29 +653,8 @@ class PrefillState {
});
}
// Stops spinning in our pools and resume spinning in main_pool_.
void ResumeMainSpinning() {
// If we didn't create a new inner pool, we didn't stop spinning on the
// main pool, so nothing to do here.
if (inner_pools_[0] == main_pool_) return;
for (hwy::ThreadPool* p : inner_pools_) {
p->SetWaitMode(hwy::PoolWaitMode::kBlock);
}
outer_pool_->SetWaitMode(hwy::PoolWaitMode::kBlock);
main_pool_->SetWaitMode(hwy::PoolWaitMode::kSpin);
}
private:
hwy::ThreadPool* main_pool_;
std::unique_ptr<hwy::ThreadPool> outer_pool_; // always allocated
// Holds a single pointer equal to main_pool_, or new allocations; in either
// case, size() is equal to outer_pool_->NumWorkers(). The first case avoids
// allocation overhead for the common case of a single query.
std::vector<hwy::ThreadPool*> inner_pools_;
// size() == outer_pool_->NumWorkers(); filled by AllocateActivations.
std::vector<Activations> activations_;
std::vector<Activations> activations_; // One per query, filled by Init.
};
// `tokens` is length `num_tokens * num_queries`. In autoregressive decode,
@ -945,12 +809,15 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, const size_t pos,
const size_t query_idx_start, const KVCaches& kv_caches,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
PerClusterPools& pools, TimingInfo& timing_info) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
// TODO: remove once all parallel sections support hierarchical parallelism.
hwy::ThreadPool& pool = pools.Inner(0);
const size_t num_queries = prompts.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize());
@ -984,14 +851,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
const size_t prefill_per_query = min_prompt_size - 1;
double prefill_start;
{
PrefillState prefill(pool, num_queries);
prefill.AllocateActivations<TConfig>(num_queries,
runtime_config.prefill_tbatch_size);
// TODO: move to Gemma, reuse across calls to Generate.
PrefillState prefill;
prefill.Init<TConfig>(num_queries, runtime_config.prefill_tbatch_size,
pools);
prefill_start = hwy::platform::Now();
prefill.Prefill<TConfig>(prompts, prefill_per_query, pos, query_idx_start,
weights, runtime_config, kv_caches);
weights, runtime_config, kv_caches, pools);
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
prefill.ResumeMainSpinning();
}
size_t interleaved_pos = (pos + prefill_per_query) * num_queries;
@ -1051,7 +918,7 @@ template <class TConfig>
void GenerateSingleT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
PerClusterPools& pools, TimingInfo& timing_info) {
const size_t num_queries = 1;
const size_t qbatch_start = 0;
@ -1062,14 +929,14 @@ void GenerateSingleT(const ByteStorageT& weights_u8,
const KVCaches kv_caches{&kv_cache, num_queries};
GenerateT<TConfig>(weights_u8, activations, runtime_config, prompts, pos,
qbatch_start, kv_caches, pool, timing_info);
qbatch_start, kv_caches, pools, timing_info);
}
template <class TConfig>
void GenerateBatchT(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos,
const KVCaches& kv_caches, hwy::ThreadPool& pool,
const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
HWY_ASSERT(prompts.size() == kv_caches.size());
// Griffin does not support query batching.
@ -1089,7 +956,7 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<TConfig>(weights_u8, activations, runtime_config, qbatch_prompts,
pos, qbatch_start, qbatch_kv, pool, timing_info);
pos, qbatch_start, qbatch_kv, pools, timing_info);
}
}
@ -1102,18 +969,18 @@ void GenerateBatchT(const ByteStorageT& weights_u8,
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) {
KVCache& kv_cache, PerClusterPools& pools, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompt, pos, kv_cache, pool, timing_info);
(weights_u8, runtime_config, prompt, pos, kv_cache, pools, timing_info);
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
GEMMA_CONFIG, const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config, const MultiplePromptsTokens& prompts,
size_t pos, const KVCaches& kv_caches, hwy::ThreadPool& pool,
size_t pos, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_CONFIG>)
(weights_u8, runtime_config, prompts, pos, kv_caches, pool, timing_info);
(weights_u8, runtime_config, prompts, pos, kv_caches, pools, timing_info);
}
#endif // HWY_ONCE

View File

@ -28,23 +28,25 @@
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
namespace gcpp {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
const ModelInfo& info, PerClusterPools& pools)
: pools_(pools), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ =
LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
PerClusterPools& pools)
: pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ =
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
weights_u8_ = CallForModel<float, AllocateCompressedWeights>(info.model,
pools_.Inner(0));
}
Gemma::~Gemma() {
@ -60,12 +62,12 @@ Gemma::~Gemma() {
extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \
KVCache& kv_cache, hwy::ThreadPool& pool, \
KVCache& kv_cache, PerClusterPools& pools, \
TimingInfo& timing_info); \
extern void GenerateBatch(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
const RuntimeConfig& runtime_config, \
const MultiplePromptsTokens& prompts, size_t pos, \
const KVCaches& kv_caches, hwy::ThreadPool& pool, \
const KVCaches& kv_caches, PerClusterPools& pools, \
TimingInfo& timing_info);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
@ -75,9 +77,9 @@ struct GenerateSingleT {
void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) const {
PerClusterPools& pools, TimingInfo& timing_info) const {
GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, kv_cache,
pool, timing_info);
pools, timing_info);
}
};
@ -86,36 +88,36 @@ struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8,
const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts, size_t pos,
const KVCaches& kv_caches, hwy::ThreadPool& pool,
const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, runtime_config, prompts, pos,
kv_caches, pool, timing_info);
kv_caches, pools, timing_info);
}
};
void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
pools_.StartSpinning();
CallForModelAndWeight<GenerateSingleT>(info_.model, info_.weight, weights_u8_,
runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info);
kv_cache, pools_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
pools_.StopSpinning();
}
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const MultiplePromptsTokens& prompts,
size_t start_pos, const KVCaches& kv_caches,
TimingInfo& timing_info) {
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
pools_.StartSpinning();
CallForModelAndWeight<GenerateBatchT>(info_.model, info_.weight, weights_u8_,
runtime_config, prompts, start_pos,
kv_caches, pool_, timing_info);
kv_caches, pools_, timing_info);
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
pools_.StopSpinning();
}
} // namespace gcpp

View File

@ -27,6 +27,7 @@
#include "gemma/common.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
@ -56,8 +57,8 @@ using SampleFunc = std::function<int(const float*, size_t)>;
// - layer index (or -1 for global outputs), e.g. "blocks" exposes x per-layer
// - pointer to the data array
// - size of the data array
using LayersOutputFunc =
std::function<void(size_t, size_t, const std::string&, int, const float*, size_t)>;
using LayersOutputFunc = std::function<void(size_t, size_t, const std::string&,
int, const float*, size_t)>;
struct RuntimeConfig {
bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
@ -121,11 +122,11 @@ using KVCaches = hwy::Span<KVCache>;
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
hwy::ThreadPool& pool);
PerClusterPools& pools);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool);
PerClusterPools& pools);
~Gemma();
const ModelInfo& Info() const { return info_; }
@ -140,7 +141,7 @@ class Gemma {
const KVCaches& kv_caches, TimingInfo& timing_info);
private:
hwy::ThreadPool& pool_;
PerClusterPools& pools_;
GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header.
@ -155,14 +156,6 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
// DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

View File

@ -27,6 +27,7 @@
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -75,9 +76,9 @@ std::string GetPrompt(std::istream& input, int verbosity,
}
// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
const AcceptFunc& accept_token, std::string& eot_line) {
void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
int verbosity, const AcceptFunc& accept_token,
std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
@ -172,13 +173,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc");
hwy::ThreadPool pool(app.num_threads);
// For many-core, pinning workers to cores helps.
if (app.num_threads > 10) {
PinWorkersToCores(pool);
}
// Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores.
PerClusterPools pools(app.max_clusters, app.num_threads);
Gemma model = CreateGemma(loader, pool);
Gemma model = CreateGemma(loader, pools);
KVCache kv_cache =
KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
@ -205,11 +204,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, inference, app);
ShowConfig(loader, inference, app, pools);
std::cout << "\n" << instructions << "\n";
}
ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
ReplGemma(model, kv_cache, inference, app.verbosity, AcceptFunc(),
app.eot_line);
}

View File

@ -23,20 +23,12 @@
#include <memory>
#include <string>
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/gemma.h" // For CreateGemma
#include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#if HWY_OS_LINUX
#include <sched.h>
#endif // HWY_OS_LINUX
#include "hwy/base.h" // HWY_IS_ASAN
namespace gcpp {
@ -58,86 +50,14 @@ static inline const char* CompiledConfig() {
}
}
static inline std::vector<size_t> LpsToCpus(
const hwy::LogicalProcessorSet& lps) {
std::vector<size_t> cpus;
cpus.reserve(lps.Count());
lps.Foreach([&cpus](size_t lp) { cpus.push_back(lp); });
return cpus;
}
static inline std::vector<size_t> AssignCpusFromTopology(
const hwy::Topology& topology, const size_t num_workers) {
// Assign CPUs to workers 0 to num_workers - 1 based on the topology.
// The assignments are done in a round-robin fashion across all clusters and
// Cores.
// For example, if we have 4 clusters, the assignments will be:
// Thread 0 -> Cluster 0, Core 0
// Thread 1 -> Cluster 1, Core 0
// Thread 2 -> Cluster 2, Core 0
// Thread 3 -> Cluster 3, Core 0
// Thread 4 -> Cluster 0, Core 1
// Thread 5 -> Cluster 1, Core 1
// ... and so on.
//
// This would result in the least amount of sharing of the last-level
// cache slices. All assignments are made from Package 0.
std::vector<std::vector<size_t>> clusters;
for (auto& package : topology.packages) {
for (auto& cluster : package.clusters) {
clusters.push_back(LpsToCpus(cluster.lps));
}
}
std::vector<size_t> assigned_cpus;
assigned_cpus.reserve(num_workers);
for (size_t i = 0; i < num_workers; ++i) {
size_t cluster_index = i % clusters.size();
size_t cpu_index = (i / clusters.size()) % clusters[cluster_index].size();
assigned_cpus.push_back(clusters[cluster_index][cpu_index]);
}
return assigned_cpus;
}
static inline void PinWorkersToCores(hwy::ThreadPool& pool) {
// Use topology to pin workers to cores if available.
hwy::Topology topology;
if (!topology.packages.empty()) {
std::vector<size_t> assigned_cpus =
AssignCpusFromTopology(topology, pool.NumWorkers());
pool.Run(0, pool.NumWorkers(),
[&assigned_cpus](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(assigned_cpus[thread]);
});
} else {
pool.Run(0, pool.NumWorkers(), [](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(thread);
});
}
}
class AppArgs : public ArgsBase<AppArgs> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = GetSupportedThreadCount();
}
}
public:
AppArgs(int argc, char* argv[]) {
InitAndParse(argc, argv);
ChooseNumThreads();
}
static inline size_t GetSupportedThreadCount() {
return HWY_MIN(hwy::ThreadPool::MaxThreads(), kMaxThreads);
}
AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
Path log; // output
int verbosity;
size_t num_threads;
size_t num_threads; // divided among the detected clusters
size_t max_clusters;
std::string eot_line;
template <class Visitor>
@ -147,11 +67,10 @@ class AppArgs : public ArgsBase<AppArgs> {
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
2);
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use.\n Default = Estimate of the "
"number of supported concurrent threads.",
2);
visitor(num_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of sockets/CCXs to use; default 0 = unlimited.", 2);
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
@ -232,14 +151,14 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
};
static inline Gemma CreateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
PerClusterPools& pools) {
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
}
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
PerClusterPools& pools) {
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), pool);
loader.Info(), pools);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> {

util/threading.h (new file, 201 lines)
View File

@ -0,0 +1,201 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::sort
#include <memory>
#include <vector>
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
namespace gcpp {
// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
// 'outer' thread pool with one worker per cluster.
//
// Useful for hierarchical parallelism, which makes sense when there are few
// but large tasks which should be parallelized by workers sharing a cache.
// This also implies lower latency for barrier synchronization of those workers.
class PerClusterPools {
using LPS = hwy::LogicalProcessorSet;
static inline std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
using CoreBitSets = std::vector<LPS>;
// Returns empty if detection failed.
CoreBitSets DetectCoresPerCluster() {
CoreBitSets clusters;
if (!have_threading_support_) return clusters;
// Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, better to abort rather than risk oversubscribing.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// Merge all clusters into one set, as a stopgap to emulate gemma-inl's
// prior single pool.
// TODO: remove once MatMul supports hierarchical parallelism.
LPS all;
// For each cluster, add its enabled *cores*.
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
cluster.lps.Foreach([&](size_t lp) {
if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
all.Set(lp);
}
});
}
/* code to reinstate:
for (const hwy::Topology::Cluster& cluster : package.clusters) {
// Only use enabled *cores*, and only add if not empty.
cluster.lps.Foreach([&](size_t lp) {
if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
all.Set(lp);
}
});
if (lps.Any()) clusters.push_back(lps);
}
*/
}
if (all.Any()) clusters.push_back(all);
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
outer_pool_.SetWaitMode(wait_mode);
for (auto& inner : inner_pools_) {
inner->SetWaitMode(wait_mode);
}
}
// The defaults for `AppArgs` `max_clusters` and `num_threads` are zero, which
// means no limit.
size_t CapIfNonzero(size_t detected, size_t user_max_or_zero) {
if (!have_threading_support_) return 0;
return (user_max_or_zero == 0) ? detected
: HWY_MIN(detected, user_max_or_zero);
}
public:
// PerClusterPools supports spin waits (see StartSpinning below). To prevent
// drastic slowdowns caused by excessive user-specified thread counts, which
// result in threads not running on their own core, we only allow for
// *upper bounds* on the number of clusters and threads. The actual number of
// clusters and threads are still limited by the detected topology.
PerClusterPools(size_t max_clusters, size_t max_threads)
: have_threading_support_(hwy::HaveThreadingSupport()),
cores_per_cluster_(DetectCoresPerCluster()),
outer_pool_(CapIfNonzero(cores_per_cluster_.size(), max_clusters)) {
// Topology detection failed - it currently requires Linux.
if (cores_per_cluster_.empty()) {
// Create a single inner pool with up to TotalLogicalProcessors() / 2
// workers, further limited by `max_threads` if nonzero, and then pin to
// the first N processors, which are typically on the first socket.
const size_t num_threads =
CapIfNonzero(hwy::TotalLogicalProcessors() / 2, max_threads);
fprintf(stderr, "CPU topology unknown, using %zu threads\n", num_threads);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
if (num_threads > 1) {
inner_pools_.back()->Run(0, num_threads,
[](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(thread);
});
}
return;
}
const size_t max_per_inner = max_threads / outer_pool_.NumWorkers();
for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
const size_t num_threads =
CapIfNonzero(cores_per_cluster_[outer].Count(), max_per_inner);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
}
// For each inner pool, pin their threads AND the associated outer thread
// (the one calling inner.Run()) to the enabled cores in the cluster.
outer_pool_.Run(
0, outer_pool_.NumWorkers(),
[this](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
hwy::ThreadPool& inner = *inner_pools_[outer];
const std::vector<size_t> cores =
CoresInLPS(cores_per_cluster_[outer]);
// May have been capped by max_threads.
HWY_ASSERT(inner.NumWorkers() <= cores.size());
inner.Run(0, inner.NumWorkers(),
[&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
}
// Spinning reduces the latency of barrier synchronization, but wastes lots of
// energy for long waits, so only do it during generation. This might also be
// unsafe in virtualized environments because we require threads to be running
// on their own core and thus responsive to the barrier synchronization.
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
// Bitset of cores, one per cluster, or empty if detection failed. Useful for
// displaying the topology.
const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; }
hwy::ThreadPool& Outer() { return outer_pool_; }
hwy::ThreadPool& Inner(size_t outer) {
HWY_ASSERT(outer < Outer().NumWorkers());
return *inner_pools_[outer];
}
private:
bool have_threading_support_;
CoreBitSets cores_per_cluster_;
hwy::ThreadPool outer_pool_;
// hwy::ThreadPool is unfortunately not marked as movable, so we have to use
// unique_ptr.
std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
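
For reference, a minimal usage sketch of the hierarchical pools as wired up in
PrefillState::Prefill above: the outer pool distributes queries across clusters, and
each outer worker hands the inner pool pinned to its cluster to the per-query work
(in Prefill that inner pool is passed down into the transformer layers). Everything
here other than the PerClusterPools/hwy::ThreadPool API (num_queries,
tokens_per_query, ProcessToken) is illustrative:

  gcpp::PerClusterPools pools(/*max_clusters=*/0, /*max_threads=*/0);  // 0 = no limit
  pools.StartSpinning();  // lower barrier latency during generation
  pools.Outer().Run(0, num_queries, [&](uint64_t qi, size_t outer_thread) {
    // Inner(outer_thread): the pool whose workers are pinned to this cluster.
    hwy::ThreadPool& inner = pools.Inner(outer_thread);
    inner.Run(0, tokens_per_query, [&](uint64_t t, size_t /*thread*/) {
      ProcessToken(qi, t);  // illustrative per-task work
    });
  });
  pools.StopSpinning();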