Use NestedPools, add NUMA infra

Improve threading.h, fix thread counts for single package/cluster systems.
Temporarily forces a single socket (package). Prefill 29.28 tps, decode 6.92 tps.

Also fix the benchmarks.cc build and switch tensor allocation to Allocator.

PiperOrigin-RevId: 687307167
Jan Wassenberg 2024-10-18 08:10:44 -07:00 committed by Copybara-Service
parent c6384574db
commit 02ce1e344f
25 changed files with 864 additions and 678 deletions
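For orientation, a minimal migration sketch (editorial, not part of the diff) of the API shift the hunks below apply; CreatePools and Allocator::Init appear in the Run() hunk further down:

  // Before: PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
  //         hwy::ThreadPool& pool = pools.Inner(0);
  // After:
  gcpp::NestedPools pools = gcpp::CreatePools(app);  // still capped by detected/enabled cores
  gcpp::Allocator::Init(pools.Topology());           // enables NUMA-aware allocation
  hwy::ThreadPool& pool = pools.Pool();              // replaces pools.Inner(0)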

View File

@ -29,10 +29,24 @@ cc_library(
) )
cc_library( cc_library(
name = "allocator", name = "threading",
hdrs = ["util/allocator.h"], hdrs = ["util/threading.h"],
deps = [ deps = [
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)
cc_library(
name = "allocator",
srcs = ["util/allocator.cc"],
hdrs = ["util/allocator.h"],
deps = [
":basics",
":threading",
"@highway//:hwy",
"@highway//:thread_pool",
], ],
) )
@ -46,16 +60,6 @@ cc_library(
], ],
) )
cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)
cc_test( cc_test(
name = "threading_test", name = "threading_test",
srcs = ["util/threading_test.cc"], srcs = ["util/threading_test.cc"],
@ -168,6 +172,7 @@ cc_test(
# for test_suite. # for test_suite.
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":allocator",
":ops", ":ops",
":threading", ":threading",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
@ -211,6 +216,7 @@ cc_library(
hdrs = ["gemma/weights.h"], hdrs = ["gemma/weights.h"],
deps = [ deps = [
":common", ":common",
"//compression:blob_store",
"//compression:compress", "//compression:compress",
"//compression:io", "//compression:io",
"@highway//:hwy", "@highway//:hwy",
@ -393,6 +399,12 @@ cc_binary(
], ],
) )
cc_library(
name = "benchmark_prompts",
hdrs = ["evals/prompts.h"],
deps = ["@highway//:hwy"],
)
cc_binary( cc_binary(
name = "benchmarks", name = "benchmarks",
srcs = [ srcs = [
@ -401,6 +413,7 @@ cc_binary(
], ],
deps = [ deps = [
":benchmark_helper", ":benchmark_helper",
":benchmark_prompts",
"@google_benchmark//:benchmark", "@google_benchmark//:benchmark",
"@highway//:hwy", # base.h "@highway//:hwy", # base.h
], ],

View File

@ -90,6 +90,7 @@ set(SOURCES
ops/sum-inl.h ops/sum-inl.h
paligemma/image.cc paligemma/image.cc
paligemma/image.h paligemma/image.h
util/allocator.cc
util/allocator.h util/allocator.h
util/app.h util/app.h
util/args.h util/args.h

View File

@ -39,8 +39,8 @@
namespace gcpp { namespace gcpp {
TEST(OptimizeTest, GradientDescent) { TEST(OptimizeTest, GradientDescent) {
PerClusterPools pools(1, 1); NestedPools pools(1);
hwy::ThreadPool& pool = pools.Inner(0); hwy::ThreadPool& pool = pools.Pool();
std::mt19937 gen(42); std::mt19937 gen(42);
const ModelInfo info = { const ModelInfo info = {

View File

@ -164,6 +164,7 @@ cc_library(
":io", ":io",
":nuq", ":nuq",
":sfp", ":sfp",
"//:allocator",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", "@highway//:nanobenchmark",
"@highway//:profiler", "@highway//:profiler",

View File

@ -197,7 +197,6 @@ struct CompressTraits<BF16> {
size_t num, CompressPerThread& tls, size_t num, CompressPerThread& tls,
const PackedSpan<Packed>& packed, const PackedSpan<Packed>& packed,
const size_t packed_ofs) { const size_t packed_ofs) {
const hn::RebindToUnsigned<decltype(df)> du;
const hn::Repartition<BF16, decltype(df)> dbf; const hn::Repartition<BF16, decltype(df)> dbf;
const size_t NF = hn::Lanes(df); const size_t NF = hn::Lanes(df);

View File

@ -20,10 +20,9 @@
#define COMPRESS_STATS 0 #define COMPRESS_STATS 0
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
#include <stdio.h> #include <stdio.h>
#include <array>
#include <cstdio>
#include <cstring> #include <cstring>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -35,70 +34,23 @@
#include "compression/io.h" #include "compression/io.h"
#include "compression/shared.h" #include "compression/shared.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
#include "compression/distortion.h" #include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // BF16
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS #if COMPRESS_STATS
#include "compression/distortion.h"
#include "hwy/stats.h" #include "hwy/stats.h"
#endif #endif
namespace gcpp { namespace gcpp {
// Compressed representation of floating-point elements. The array length may // Base class for rank-1 or 2 tensors (vector or matrix).
// differ from the number of elements. Associated operations such as Dot are // Supports both dynamic and compile-time sizing.
// implemented in SIMD code and are thus non-member functions. // Holds metadata and a non-owning pointer to the data, owned by the derived
template <typename Packed, size_t kCapacity> // MatStorageT class.
class CompressedArray { // This class also provides easy conversion from/to a table of contents for a
public: // BlobStore file, and a templated (compile-time) accessor for a 2-d array of
using value_type = Packed; // fixed inner dimension and type.
// Note that whenever you access data(), you have to consider a scale() that
// may be different from 1.0f.
Packed* data() { return data_.data(); }
const Packed* data() const { return data_.data(); }
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
// A scale of 0 indicates that the scale has likely never been set, so is
// "implicitly 1".
const Packed* data_scale1() const {
HWY_ASSERT(scale() == 1.f || scale() == 0.f);
return data_.data();
}
// Decoded elements should be multiplied by this to restore their original
// range. This is required because SfpStream can only encode a limited range
// of magnitudes.
float scale() const { return scale_[0]; }
void set_scale(float scale) { scale_[0] = scale; }
constexpr size_t NumElements() const { return kCapacity; }
// Returns total number of packed elements for `BlobReader::Enqueue` and
// `Compress`. This differs from `NumElements` for `Packed=NuqStream`.
PackedSpan<Packed> GetSpan() { return MakeSpan(data(), data_.size()); }
PackedSpan<const Packed> GetSpan() const {
return MakeSpan(data(), data_.size());
}
private:
std::array<Packed, CompressedArrayElements<Packed>(kCapacity)> data_;
// Blobs are at least kBlobAlign bytes anyway.
float scale_[kBlobAlign / sizeof(float)];
};
// Yet another array class. This one is intended to be compatible with
// CompressedArray, but have both run-time sizing and compile-time constant
// size.
// It also provides easy conversion from/to a table of contents for a BlobStore
// file, and a templated (compile-time) accessor for a 2-d array of fixed inner
// dimension and type.
// The base class is intended for accessing the metadata, without needing to
// know any of the template arguments.
// It holds only a borrowed pointer to the data, but all metadata.
// It is designed to be put in a vector, and has default copy and operator=, so // It is designed to be put in a vector, and has default copy and operator=, so
// it is easy to read/write a blob_store file. // it is easy to read/write a blob_store file.
// The derived class or an external class owns the data.
class MatPtr { class MatPtr {
public: public:
// Full constructor for dynamic sizing. // Full constructor for dynamic sizing.
@ -111,12 +63,12 @@ class MatPtr {
rows_(rows), rows_(rows),
cols_(cols), cols_(cols),
ptr_(nullptr) {} ptr_(nullptr) {}
// Default constructor doesn't set anything. // Default is to leave all fields default-initialized.
MatPtr() = default; MatPtr() = default;
virtual ~MatPtr(); virtual ~MatPtr();
// Number of hwy::uint128_t in a TOC entry. // Number of hwy::uint128_t in a TOC entry.
// Note that the old-style BlobStore files Only have a list of keys and size. // Note that the old-style BlobStore files only have a list of keys and size.
// The new-style BlobStore files have an entry called "toc" that contains a // The new-style BlobStore files have an entry called "toc" that contains a
// vector of 4-tuples of // vector of 4-tuples of
// (name, type, (num_elements, element_size), (rows, cols)). // (name, type, (num_elements, element_size), (rows, cols)).
@ -144,6 +96,7 @@ class MatPtr {
} }
// Compatibility interface for CompressedArray. // Compatibility interface for CompressedArray.
// TODO: remove.
template <typename T> template <typename T>
T* data() { T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_); return HWY_RCAST_ALIGNED(T*, ptr_);
@ -177,7 +130,6 @@ class MatPtr {
// Returns the number of bytes in the array. // Returns the number of bytes in the array.
size_t SizeBytes() const { return num_elements_ * element_size_; } size_t SizeBytes() const { return num_elements_ * element_size_; }
size_t CompressedSize() const { return SizeBytes(); }
// Returns the number of rows in the 2-d array (outer dimension). // Returns the number of rows in the 2-d array (outer dimension).
size_t Rows() const { return rows_; } size_t Rows() const { return rows_; }
@ -211,8 +163,8 @@ class MatPtr {
} }
// Calls func on the upcasted type. Since MatPtr by design is not templated, // Calls func on the upcasted type. Since MatPtr by design is not templated,
// here we provide a way to get to the derived type, provided that the type // here we provide a way to get to the derived type, provided that `Type()`
// matches one of a known short-list. // is one of the strings returned by `TypeName()`.
template <class FuncT, typename... TArgs> template <class FuncT, typename... TArgs>
decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args); decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
@ -243,8 +195,6 @@ class MatPtr {
template <typename MatT> template <typename MatT>
class MatPtrT : public MatPtr { class MatPtrT : public MatPtr {
public: public:
using value_type = MatT;
// Full constructor for dynamic sizing. // Full constructor for dynamic sizing.
MatPtrT(const std::string& name, size_t rows, size_t cols) MatPtrT(const std::string& name, size_t rows, size_t cols)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {} : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
@ -276,20 +226,13 @@ class MatPtrT : public MatPtr {
} }
return name; return name;
} }
// Sets the number of elements in the array. For use when the number of // Sets the number of elements in the array. For use when the number of
// elements is != rows * cols ONLY. // elements is != rows * cols ONLY.
void SetNumElements(size_t num_elements) { void SetNumElements(size_t num_elements) {
num_elements_ = CompressedArrayElements<MatT>(num_elements); num_elements_ = CompressedArrayElements<MatT>(num_elements);
} }
// Fast 2-d accessor for a 2-d array of fixed inner dimension and type.
template <typename T = MatT, size_t kInner>
const T& AtT(size_t row, size_t col) const {
size_t index = row * kInner + col;
HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
}
// 2-d Accessor for a specific type but with a dynamic inner dimension. // 2-d Accessor for a specific type but with a dynamic inner dimension.
template <typename T = MatT> template <typename T = MatT>
const T& At(size_t row, size_t col) const { const T& At(size_t row, size_t col) const {
@ -299,17 +242,15 @@ class MatPtrT : public MatPtr {
} }
// 1-d Accessor for a specific type. // 1-d Accessor for a specific type.
template <typename T = MatT> // TODO: replace this with a Foreach(), or at least a ForEachRow().
const T& At(size_t index) const { const MatT& At(size_t index) const {
HWY_DASSERT(index < num_elements_); HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
}
template <typename T = MatT>
T& At(size_t index) {
return HWY_RCAST_ALIGNED(T*, ptr_)[index];
} }
MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
// Compatibility interface for CompressedArray. // Compatibility interface for CompressedArray.
// TODO: remove
template <typename T = MatT> template <typename T = MatT>
T* data() { T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_); return HWY_RCAST_ALIGNED(T*, ptr_);
@ -353,15 +294,14 @@ class MatStorageT : public MatPtrT<MatT> {
public: public:
// Full constructor for dynamic sizing. // Full constructor for dynamic sizing.
MatStorageT(const std::string& name, size_t rows, size_t cols) MatStorageT(const std::string& name, size_t rows, size_t cols)
: MatPtrT<MatT>(name, rows, cols), : MatPtrT<MatT>(name, rows, cols) {
data_(hwy::AllocateAligned<MatT>( Allocate();
hwy::DivCeil(this->SizeBytes(), sizeof(MatT)))) {
this->ptr_ = data_.get();
} }
// Can copy the metadata, from a MatPtr, and allocate later. // Can copy the metadata, from a MatPtr, and allocate later.
MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {} MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
~MatStorageT() = default;
// No copying of MatStorageT as it contains big data. // Move-only because this contains a unique_ptr.
MatStorageT(const MatStorageT& other) = delete; MatStorageT(const MatStorageT& other) = delete;
MatStorageT& operator=(const MatStorageT& other) = delete; MatStorageT& operator=(const MatStorageT& other) = delete;
MatStorageT(MatStorageT&& other) = default; MatStorageT(MatStorageT&& other) = default;
@ -377,7 +317,7 @@ class MatStorageT : public MatPtrT<MatT> {
} else { } else {
this->num_elements_ = num_elements; this->num_elements_ = num_elements;
} }
data_ = hwy::AllocateAligned<MatT>(num_elements); data_ = Allocator::Alloc<MatT>(num_elements);
this->ptr_ = data_.get(); this->ptr_ = data_.get();
} }
@ -388,8 +328,6 @@ class MatStorageT : public MatPtrT<MatT> {
} }
private: private:
// Aligned data array.
// std::unique_ptr<MatT[]> data_;
hwy::AlignedFreeUniquePtr<MatT[]> data_; hwy::AlignedFreeUniquePtr<MatT[]> data_;
}; };
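As a usage sketch (editorial, not part of the commit): MatStorageT now allocates through Allocator in its constructor, and the 1-d At() accessor above is the non-const write path; Allocator::Init is assumed to have run beforehand, as in the Run() hunk below. The name and dimensions here are arbitrary:

  gcpp::MatStorageT<float> mat("w", /*rows=*/4, /*cols=*/8);  // ctor calls Allocate()
  for (size_t i = 0; i < 4 * 8; ++i) mat.At(i) = 1.0f;        // non-const 1-d accessor
  float* raw = mat.data<float>();  // CompressedArray-compatibility interface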
@ -507,7 +445,7 @@ class CompressStats {
}; };
#else #else
struct CompressStats { struct CompressStats {
void Notify(const DistortionStats&) {} void Notify(...) {}
void NotifyIn(int) {} void NotifyIn(int) {}
void Assimilate(const CompressStats&) {} void Assimilate(const CompressStats&) {}
void PrintAll() {} void PrintAll() {}
@ -526,18 +464,17 @@ struct CompressWorkingSet {
// Functor called for each tensor, which loads them and their scaling factors // Functor called for each tensor, which loads them and their scaling factors
// from BlobStore. // from BlobStore.
class CacheLoader { class ReadFromBlobStore {
public: public:
explicit CacheLoader(const Path& blob_filename) { explicit ReadFromBlobStore(const Path& blob_filename) {
err_ = reader_.Open(blob_filename); err_ = reader_.Open(blob_filename);
if (err_ != 0) { if (HWY_UNLIKELY(err_ != 0)) {
fprintf(stderr, fprintf(stderr, "Error %d opening BlobStore %s.\n", err_,
"Cached compressed weights does not exist yet (code %d), " blob_filename.path.c_str());
"loading from file: %s.\n", return; // avoid overwriting err_ to ensure ReadAll will fail.
err_, blob_filename.path.c_str());
} }
err_ = file_toc_.LoadToc(reader_); err_ = file_toc_.LoadToc(reader_);
if (err_ != 0) { if (HWY_UNLIKELY(err_ != 0)) {
fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_); fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_);
} }
} }
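A short usage sketch (editorial) of the renamed loader, mirroring the weights.cc hunk further below; `weights_path` is a placeholder:

  gcpp::ReadFromBlobStore loader(weights_path);
  const gcpp::ForEachType fet = loader.HaveToc() ? gcpp::ForEachType::kLoadWithToc
                                                 : gcpp::ForEachType::kLoadNoToc;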

View File

@ -36,10 +36,8 @@
#include "util/args.h" #include "util/args.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/per_target.h" #include "hwy/per_target.h" // VectorBytes
#include "hwy/timer.h" #include "hwy/timer.h"
namespace gcpp { namespace gcpp {
@ -57,7 +55,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app) const AppArgs& app)
: pools_(app.max_clusters, app.max_threads, app.pin) { : pools_(CreatePools(app)) {
InferenceArgs mutable_inference = inference; InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference); AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader; LoaderArgs mutable_loader = loader;
@ -217,7 +215,7 @@ void LogSpeedStats(double time_start, size_t total_tokens) {
} }
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools) { NestedPools& pools) {
loader.Print(app.verbosity); loader.Print(app.verbosity);
inference.Print(app.verbosity); inference.Print(app.verbosity);
app.Print(app.verbosity); app.Print(app.verbosity);
@ -228,21 +226,15 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
char cpu100[100] = "unknown"; char cpu100[100] = "unknown";
(void)hwy::platform::GetCpuString(cpu100); (void)hwy::platform::GetCpuString(cpu100);
// TODO: call TopologyString() once we have NestedPools.
const std::vector<hwy::LogicalProcessorSet>& clusters =
pools.CoresPerCluster();
const size_t per_cluster =
clusters.empty() ? 0 : pools.CoresPerCluster().front().Count();
fprintf(stderr, fprintf(stderr,
"Date & Time : %s" // dt includes \n "Date & Time : %s" // dt includes \n
"CPU : %s\n" "CPU : %s\n"
"CPU topology : %zux%zu, using %zux%zu\n" "CPU topology : %s\n"
"Instruction set : %s (%zu bits)\n" "Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n" "Compiled config : %s\n"
"Weight Type : %s\n" "Weight Type : %s\n"
"EmbedderInput Type : %s\n", "EmbedderInput Type : %s\n",
dt, cpu100, pools.CoresPerCluster().size(), per_cluster, dt, cpu100, pools.TopologyString(),
pools.Outer().NumWorkers(), pools.Inner(0).NumWorkers(),
hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8, hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
CompiledConfig(), StringFromType(loader.Info().weight), CompiledConfig(), StringFromType(loader.Info().weight),
TypeName<EmbedderInputT>()); TypeName<EmbedderInputT>());

View File

@ -106,7 +106,7 @@ class GemmaEnv {
private: private:
// Thread pool for running inference. // Thread pool for running inference.
PerClusterPools pools_; NestedPools pools_;
// Random number generator. // Random number generator.
std::mt19937 gen_; std::mt19937 gen_;
// The model to run inference on. // The model to run inference on.
@ -121,7 +121,7 @@ class GemmaEnv {
void LogSpeedStats(double time_start, size_t total_tokens); void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
PerClusterPools& pools); NestedPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
} // namespace gcpp } // namespace gcpp

View File

@ -54,7 +54,7 @@ int main(int argc, char** argv) {
} }
// Instantiate model and KV Cache // Instantiate model and KV Cache
gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); gcpp::NestedPools pools = gcpp::CreatePools(app);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache = gcpp::KVCache kv_cache =
gcpp::KVCache::Create(model.GetModelConfig(), gcpp::KVCache::Create(model.GetModelConfig(),

View File

@ -94,7 +94,7 @@ struct Activations {
return inv_timescale; return inv_timescale;
} }
void Allocate(size_t batch_size, PerClusterPools& pools) { void Allocate(size_t batch_size, NestedPools& pools) {
post_qk = layer_config.post_qk; post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim; const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim;

View File

@ -1131,7 +1131,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(), image_patches[i].get(),
weights.vit_img_embedding_bias.data_scale1(), weights.vit_img_embedding_bias.data_scale1(),
activations.x.Batch(i), activations.env.Pools().Outer()); activations.x.Batch(i), activations.env.Pool());
} }
// Add position embeddings. // Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(), AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
@ -1416,7 +1416,7 @@ template <typename T>
void GenerateSingleT(const ModelWeightsStorage& model, void GenerateSingleT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end, const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, PerClusterPools& pools, KVCache& kv_cache, NestedPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1; constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0; const size_t qbatch_start = 0;
@ -1440,7 +1440,7 @@ void GenerateBatchT(const ModelWeightsStorage& model,
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, PerClusterPools& pools, const KVCaches& kv_caches, NestedPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries); HWY_ASSERT(queries_pos.size() == num_queries);
@ -1477,7 +1477,7 @@ template <typename T>
void GenerateImageTokensT(const ModelWeightsStorage& model, void GenerateImageTokensT(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens, const Image& image, ImageTokens& image_tokens,
PerClusterPools& pools) { NestedPools& pools) {
if (model.Config().vit_layer_configs.empty()) { if (model.Config().vit_layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens."); HWY_ABORT("Model does not support generating image tokens.");
} }
@ -1500,7 +1500,7 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
void GenerateSingle( // NOLINT(misc-definitions-in-headers) void GenerateSingle( // NOLINT(misc-definitions-in-headers)
GEMMA_TYPE, const ModelWeightsStorage& model, GEMMA_TYPE, const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
size_t prefix_end, KVCache& kv_cache, PerClusterPools& pools, size_t prefix_end, KVCache& kv_cache, NestedPools& pools,
TimingInfo& timing_info) { TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>)
(model, runtime_config, prompt, pos, prefix_end, kv_cache, pools, (model, runtime_config, prompt, pos, prefix_end, kv_cache, pools,
@ -1512,7 +1512,7 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers)
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
PerClusterPools& pools, TimingInfo& timing_info) { NestedPools& pools, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>)
(model, runtime_config, queries_prompt, queries_pos, queries_prefix_end, (model, runtime_config, queries_prompt, queries_pos, queries_prefix_end,
kv_caches, pools, timing_info); kv_caches, pools, timing_info);
@ -1521,7 +1521,7 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers)
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
GEMMA_TYPE, const ModelWeightsStorage& model, GEMMA_TYPE, const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const Image& image, const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, PerClusterPools& pools) { ImageTokens& image_tokens, NestedPools& pools) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>) HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
(model, runtime_config, image, image_tokens, pools); (model, runtime_config, image, image_tokens, pools);
} }

View File

@ -27,6 +27,7 @@
#include <vector> #include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
@ -39,16 +40,16 @@
namespace gcpp { namespace gcpp {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, PerClusterPools& pools) const ModelInfo& info, NestedPools& pools)
: pools_(pools), tokenizer_(tokenizer_path), info_(info) { : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
model_.Load(weights, info.model, info.weight, pools_.Inner(0)); model_.Load(weights, info.model, info.weight, pools_.Pool());
} }
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
PerClusterPools& pools) NestedPools& pools)
: pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) { : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
HWY_ASSERT(info.weight == Type::kF32); HWY_ASSERT(info.weight == Type::kF32);
model_.Allocate(info.model, info.weight, pools_.Inner(0)); model_.Allocate(info.model, info.weight, pools_.Pool());
} }
Gemma::~Gemma() { Gemma::~Gemma() {
@ -63,17 +64,16 @@ Gemma::~Gemma() {
const RuntimeConfig& runtime_config, \ const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \ const PromptTokens& prompt, size_t pos, \
size_t prefix_end, KVCache& kv_cache, \ size_t prefix_end, KVCache& kv_cache, \
PerClusterPools& pools, TimingInfo& timing_info); \ NestedPools& pools, TimingInfo& timing_info); \
extern void GenerateBatch( \ extern void GenerateBatch( \
TWEIGHT, const ModelWeightsStorage& model, \ TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \ const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
const KVCaches& kv_caches, PerClusterPools& pools, \ const KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info); \
TimingInfo& timing_info); \
extern void GenerateImageTokens( \ extern void GenerateImageTokens( \
TWEIGHT, const ModelWeightsStorage& model, \ TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, const Image& image, \ const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, PerClusterPools& pools); ImageTokens& image_tokens, NestedPools& pools);
GEMMA_DECLARE(float) GEMMA_DECLARE(float)
GEMMA_DECLARE(BF16) GEMMA_DECLARE(BF16)
GEMMA_DECLARE(NuqStream) GEMMA_DECLARE(NuqStream)
@ -85,7 +85,7 @@ struct GenerateSingleT {
void operator()(const ModelWeightsStorage& model, void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end, const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, PerClusterPools& pools, KVCache& kv_cache, NestedPools& pools,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {
GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end, GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
kv_cache, pools, timing_info); kv_cache, pools, timing_info);
@ -99,7 +99,7 @@ struct GenerateBatchT {
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, PerClusterPools& pools, const KVCaches& kv_caches, NestedPools& pools,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {
GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos, GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
queries_prefix_end, kv_caches, pools, timing_info); queries_prefix_end, kv_caches, pools, timing_info);
@ -110,7 +110,7 @@ template <class TConfig>
struct GenerateImageTokensT { struct GenerateImageTokensT {
void operator()(const ModelWeightsStorage& model, void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const Image& image, const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, PerClusterPools& pools) const { ImageTokens& image_tokens, NestedPools& pools) const {
GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens, GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
pools); pools);
} }

View File

@ -183,11 +183,10 @@ struct TimingInfo {
class Gemma { class Gemma {
public: public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
PerClusterPools& pools); NestedPools& pools);
// Allocates weights, caller is responsible for filling them. // Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools);
PerClusterPools& pools);
~Gemma(); ~Gemma();
const ModelConfig& GetModelConfig() const { return model_.Config(); } const ModelConfig& GetModelConfig() const { return model_.Config(); }
@ -229,7 +228,7 @@ class Gemma {
const Image& image, ImageTokens& image_tokens); const Image& image, ImageTokens& image_tokens);
private: private:
PerClusterPools& pools_; NestedPools& pools_;
GemmaTokenizer tokenizer_; GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header. // Type-erased so that this can be defined in the header.

View File

@ -188,9 +188,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args,
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PROFILER_ZONE("Run.misc"); PROFILER_ZONE("Run.misc");
// TODO: remove once MatMul is updated.
app.max_packages = 1;
// Note that num_threads is an upper bound; we also limit to the number of // Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores. // detected and enabled cores.
PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); NestedPools pools = CreatePools(app);
Allocator::Init(pools.Topology());
Gemma model = CreateGemma(loader, pools); Gemma model = CreateGemma(loader, pools);
KVCache kv_cache = KVCache kv_cache =

View File

@ -21,6 +21,7 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "compression/blob_store.h"
#include "compression/compress.h" #include "compression/compress.h"
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
@ -36,7 +37,7 @@ namespace gcpp {
template <typename T> template <typename T>
struct TensorLoader { struct TensorLoader {
void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet, void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
CacheLoader& loader) { ReadFromBlobStore& loader) {
weights.ForEachTensor( weights.ForEachTensor(
{&weights}, fet, {&weights}, fet,
[&loader](const char* name, hwy::Span<MatPtr*> tensors) { [&loader](const char* name, hwy::Span<MatPtr*> tensors) {
@ -52,7 +53,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
HWY_ABORT("The model weights file '%s' does not exist.", HWY_ABORT("The model weights file '%s' does not exist.",
weights.path.c_str()); weights.path.c_str());
} }
CacheLoader loader(weights); ReadFromBlobStore loader(weights);
ForEachType fet = ForEachType fet =
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
if (fet == ForEachType::kLoadWithToc) { if (fet == ForEachType::kLoadWithToc) {

View File

@ -15,8 +15,6 @@
#include <stddef.h> #include <stddef.h>
#include <array>
#include "compression/compress.h" #include "compression/compress.h"
#include "hwy/base.h" #include "hwy/base.h"
@ -376,15 +374,6 @@ HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) {
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num); return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
} }
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <typename MatT, size_t kCapacity, typename VT>
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
const VT* vec_aligned, size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() *
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
}
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <typename MatT, typename VT> template <typename MatT, typename VT>
HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs, HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,

View File

@ -470,7 +470,7 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
env.Pool().Run( env.Pool().Run(
0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR { 0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
// TODO: when using PerClusterPool, compute lp from outer and inner. // TODO: when using PerClusterPool, compute lp from outer and inner.
float* HWY_RESTRICT buf = env.Buf(thread); float* HWY_RESTRICT buf = env.Buf().Batch(thread);
const size_t tx = idx_tile % tilesX; const size_t tx = idx_tile % tilesX;
const size_t ty = idx_tile / tilesX; const size_t ty = idx_tile / tilesX;
const size_t row_ac = ty * kRegRows; const size_t row_ac = ty * kRegRows;

View File

@ -18,12 +18,15 @@
#include <stddef.h> #include <stddef.h>
#include "util/allocator.h" // RowVectorBatch // IWYU pragma: begin_exports
#include "util/threading.h" #include "util/threading.h"
#include "hwy/aligned_allocator.h" // IWYU pragma: export #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/per_target.h" // IWYU pragma: end_exports
#include "util/allocator.h" // RowVectorBatch
#include "hwy/per_target.h" // VectorBytes
namespace gcpp { namespace gcpp {
@ -80,19 +83,18 @@ Mat<const T> ConstMat(const T* HWY_RESTRICT ptr, size_t cols) {
class MatMulEnv { class MatMulEnv {
public: public:
MatMulEnv() : pools_(nullptr) {} MatMulEnv() : pools_(nullptr) {}
explicit MatMulEnv(PerClusterPools& pools) : pools_(&pools) { explicit MatMulEnv(NestedPools& pools) : pools_(&pools) {
const size_t num_lp = pools.NumLP(); const size_t N = hwy::VectorBytes() / sizeof(float);
const size_t NF = hwy::VectorBytes() / sizeof(float); buf_ = RowVectorBatch<float>(pools.MaxWorkers(), 16 * N);
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
} }
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); } RowVectorBatch<float>& Buf() { return buf_; }
PerClusterPools& Pools() const { return *pools_; } NestedPools& Pools() const { return *pools_; }
hwy::ThreadPool& Pool() const { return pools_->Inner(0); } hwy::ThreadPool& Pool() const { return pools_->Pool(); }
private: private:
RowVectorBatch<float> buf_; RowVectorBatch<float> buf_;
PerClusterPools* pools_; NestedPools* pools_;
}; };
} // namespace gcpp } // namespace gcpp
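A construction sketch (editorial), following the order used in the matmul_test hunk below; the env now sizes its scratch buffer from pools.MaxWorkers() rather than the number of logical processors:

  gcpp::NestedPools pools(4, /*pin=*/1);    // as in matmul_test
  gcpp::Allocator::Init(pools.Topology());  // reordered before MatMulEnv in the test
  gcpp::MatMulEnv env(pools);
  float* scratch = env.Buf().Batch(0);      // per-worker row, cf. Buf().Batch(thread) in MatMul
  hwy::ThreadPool& pool = env.Pool();       // innermost pool, replaces Inner(0)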

View File

@ -14,9 +14,14 @@
// limitations under the License. // limitations under the License.
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32. // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
#define HWY_DISABLED_TARGETS HWY_SCALAR #define HWY_DISABLED_TARGETS HWY_SCALAR
#endif #endif
#endif
#include "ops/matmul.h" #include "ops/matmul.h"
@ -26,8 +31,8 @@
#include <memory> #include <memory>
#include "compression/compress.h" #include "compression/compress.h"
#include "util/allocator.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h" #include "hwy/timer.h"
@ -165,7 +170,7 @@ template <typename MatTA, typename MatTB>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_trans, const float scale, const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add_row, MatMulEnv& env,
float* HWY_RESTRICT out) { float* HWY_RESTRICT out) {
// MatTA can be any Packed except NuqStream because it uses pointer // MatTA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not // arithmetic, because it is the second argument to Dot, which does not
@ -176,24 +181,22 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const PackedSpan<const MatTB> b_span = const PackedSpan<const MatTB> b_span =
MakeSpan(b_trans, cols_a_rows_b * cols_bc); MakeSpan(b_trans, cols_a_rows_b * cols_bc);
env.Pools().Outer().Run( StaticPartitionRowsAndCols(
0, rows_ac, [&](const uint64_t i, size_t o_thread) HWY_ATTR { env.Pools(), rows_ac, cols_bc, sizeof(MatTB),
hwy::ThreadPool& inner = env.Pools().Inner(o_thread); [&](size_t /*node*/, hwy::ThreadPool& pool,
if (add != nullptr) { const size_t /*worker_offset*/, const size_t row_begin,
inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) { const size_t row_end, const size_t col_begin, const size_t col_end) {
out[i * cols_bc + j] = pool.Run(row_begin, row_end,
scale * Dot(df, b_span, j * cols_a_rows_b, [&](const uint64_t row, size_t /*thread*/) {
a + i * cols_a_rows_b, cols_a_rows_b) + for (size_t col = col_begin; col < col_end; ++col) {
add[j]; const float add = add_row ? add_row[col] : 0.0f;
}); out[row * cols_bc + col] =
} else { scale * Dot(df, b_span, col * cols_a_rows_b,
inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) { a + row * cols_a_rows_b, cols_a_rows_b) +
out[i * cols_bc + j] = add;
scale * Dot(df, b_span, j * cols_a_rows_b,
a + i * cols_a_rows_b, cols_a_rows_b);
});
} }
}); });
});
} }
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
@ -261,9 +264,10 @@ void TestAllMatMul() {
return; return;
} }
PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1); NestedPools pools(4, /*pin=*/1);
MatMulEnv env(pools);
pools.StartSpinning(); pools.StartSpinning();
Allocator::Init(pools.Topology());
MatMulEnv env(pools);
using F32 = float; using F32 = float;
using SFP = SfpStream; using SFP = SfpStream;

View File

@ -45,6 +45,15 @@ namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <class ArrayT, typename VT>
HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned,
size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs,
vec_aligned, num);
}
// Simple version without tiling nor threading, but two offsets/outputs and // Simple version without tiling nor threading, but two offsets/outputs and
// always with addition. // always with addition.
template <typename ArrayT, typename VecT, typename AddT> template <typename ArrayT, typename VecT, typename AddT>

util/allocator.cc (new file, 183 lines)
View File

@ -0,0 +1,183 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/allocator.h"
#include <stdio.h>
#include <vector>
#include "util/basics.h" // MaybeCheckInitialized
#if GEMMA_NUMA
#if HWY_OS_WIN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
#elif HWY_OS_LINUX
#include <sys/syscall.h>
#include <cerrno>
#endif // HWY_OS_*
#endif // GEMMA_NUMA
namespace gcpp {
/*static*/ size_t Allocator::bytes_per_page_;
/*static*/ bool Allocator::use_numa_;
/*static*/ size_t Allocator::alignment_;
/*static*/ size_t Allocator::DetectPageSize() {
#if HWY_OS_WIN
SYSTEM_INFO sys_info;
GetSystemInfo(&sys_info);
return sys_info.dwPageSize;
#elif HWY_OS_LINUX
return sysconf(_SC_PAGESIZE);
#else
return 0;
#endif
}
#if GEMMA_NUMA && HWY_OS_LINUX
using Ret = long; // NOLINT(runtime/int)
using UL = unsigned long; // NOLINT(runtime/int)
static constexpr size_t ULBits = sizeof(UL) * 8;
// Calling via syscall avoids a dependency on libnuma.
struct SyscallWrappers {
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
unsigned flags) {
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
return syscall(__NR_mbind, ptr, bytes, mode, nodes, max_nodes, flags);
};
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
int* status, int flags) {
MaybeCheckInitialized(pages, count * sizeof(void*));
MaybeCheckInitialized(nodes, count * sizeof(int));
MaybeCheckInitialized(status, count * sizeof(int));
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
}
};
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
const int* status) {
// Return value 0 does not actually guarantee all pages were moved.
size_t num_busy = 0;
for (size_t i = 0; i < num_pages; ++i) {
if (status[i] == -EBUSY) {
++num_busy;
// Touch
hwy::ZeroBytes(pages[i], 8);
} else if (status[i] != static_cast<int>(node)) {
fprintf(stderr, "Error %d moving pages[%zu]=%p to node %zu (errno %d)\n",
status[i], i, pages[i], node, errno);
}
}
return num_busy;
}
// Attempts to move(!) memory to the given NUMA node, typically obtained from
// `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. Using `mbind`
// directly is easier than calling libnuma's `numa_move_pages`, which requires
// an array of pages. Note that `numa_tonode_memory` is insufficient because
// it does not specify the `MPOL_MF_MOVE` flag, so it only sets the policy,
// which means it would have to be called before pages are faulted in, but
// `aligned_allocator.h` modifies the first bytes for its bookkeeping.
// May overwrite some of the memory with zeros.
static void BindMemory(void* ptr, size_t bytes, size_t node) {
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
// Avoid mbind because it does not report why it failed, which is most likely
// because pages are busy, in which case we want to know which.
#if 0
// nodemask with only the given node set.
UL nodes[hwy::DivCeil(kMaxNodes, ULBits)] = {};
nodes[node / ULBits] = 1ULL << (node % ULBits);
const int mode = 2; // MPOL_BIND
const unsigned flags = 3; // MPOL_MF_MOVE | MPOL_MF_STRICT
const int ret =
SyscallWrappers::mbind(ptr, bytes, mode, nodes, kMaxNodes, flags);
if (ret != 0) {
fprintf(stderr, "Failed to bind %p %zu to node %zu (errno %d)\n", ptr,
bytes, node, errno);
}
#elif 1
const unsigned flags = 2; // MPOL_MF_MOVE
const size_t bytes_per_page = static_cast<size_t>(sysconf(_SC_PAGESIZE));
HWY_ASSERT(bytes % bytes_per_page == 0);
const size_t num_pages = bytes / bytes_per_page;
std::vector<void*> pages;
pages.reserve(num_pages);
for (size_t i = 0; i < num_pages; ++i) {
pages.push_back(static_cast<uint8_t*>(ptr) + i * bytes_per_page);
}
std::vector<int> nodes(num_pages, node);
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
Ret ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
size_t num_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (num_busy != 0) {
// Try again
ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
const size_t num_busy_before = num_busy;
num_busy = CountBusyPages(num_pages, node, pages.data(), status.data());
fprintf(
stderr,
"second try still %zu busy, was %zu. 2nd ret %d status %d %d %d %d\n",
num_busy, num_busy_before, static_cast<int>(ret), status[0], status[1],
status[2], status[3]);
}
if (ret < 0) {
fprintf(stderr,
"Failed to bind %p %zu to node %zu (errno %d) status %d %d\n", ptr,
bytes, node, errno, status[0], status[1]);
}
#endif
}
#else
// TODO: support other OSes.
static void BindMemory(void*, size_t, size_t) {}
#endif // GEMMA_NUMA && HWY_OS_LINUX
void BindTensor(NestedPools& nested, size_t rows, size_t cols,
size_t bytes_per_col, void* ptr) {
if (!Allocator::UseNUMA()) return;
uint8_t* p8 = static_cast<uint8_t*>(ptr);
const size_t bytes_per_row = cols * bytes_per_col;
StaticPartitionRowsAndCols(
nested, rows, cols, bytes_per_col,
[&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/,
const size_t row_begin, const size_t row_end, const size_t col_begin,
const size_t col_end) {
for (size_t row = row_begin; row < row_end; ++row) {
uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col;
const size_t slice_size = (col_end - col_begin) * bytes_per_col;
BindMemory(slice, slice_size, node);
}
});
}
} // namespace gcpp
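An editorial sketch of driving the new NUMA path end to end (not part of the commit); the NestedPools constructor arguments follow the matmul_test hunk, and note that this commit still forces use_numa_ = false, so BindTensor is currently a no-op:

  gcpp::NestedPools pools(4, /*pin=*/1);
  gcpp::Allocator::Init(pools.Topology());
  const size_t rows = 1024, cols = 4096;
  hwy::AlignedFreeUniquePtr<float[]> w = gcpp::Allocator::Alloc<float>(rows * cols);
  // Moves each cluster's row slice of each package's column slice to its node.
  gcpp::BindTensor(pools, rows, cols, sizeof(float), w.get());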

View File

@ -19,8 +19,29 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "hwy/aligned_allocator.h" // IWYU pragma: export #include <cstdlib> // std::aligned_alloc
// IWYU pragma: begin_exports
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// IWYU pragma: end_exports
#ifndef GEMMA_NUMA
// The check below requires two #if, hence start with 0 and redefine to 1.
#define GEMMA_NUMA 0
// To avoid a dependency on libnuma, use syscalls directly. We require the
// six-argument form of syscall(), which glibc has supported since around 2010.
#if defined(__GLIBC__) && defined(__GLIBC_PREREQ)
#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11)
#undef GEMMA_NUMA
#define GEMMA_NUMA 1
#endif
#endif
#endif // GEMMA_NUMA
namespace gcpp { namespace gcpp {
@ -74,6 +95,136 @@ class RowVectorBatch {
size_t len_; // columns in the matrix = vector length size_t len_; // columns in the matrix = vector length
}; };
// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for
// convenience - avoids passing around a reference.
class Allocator {
public:
static void Init(const BoundedTopology& topology) {
bytes_per_page_ = DetectPageSize();
HWY_ASSERT(bytes_per_page_ <= (4 << 20));
// NUMA only makes sense if:
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
use_numa_ = (bytes_per_page_ != 0 && bytes_per_page_ <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1;
// TODO: remove once tensors are page-aligned.
use_numa_ = false;
fprintf(stderr, "Warning: disabling use_numa_\n");
alignment_ = use_numa_ ? bytes_per_page_ : HWY_ALIGNMENT;
}
static bool UseNUMA() { return use_numa_; }
// BindTensor requires row pointers and lengths be a multiple of this.
static size_t Alignment() { return alignment_; }
template <typename T>
static hwy::AlignedFreeUniquePtr<T[]> Alloc(size_t num) {
// For non-NUMA, use the Highway allocator because it defends against 2k
// aliasing.
if (!use_numa_) return hwy::AllocateAligned<T>(num);
constexpr size_t kSize = sizeof(T);
// Ensure the `bytes = num * kSize` computation did not overflow.
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != num) {
return hwy::AlignedFreeUniquePtr<T[]>(); // overflowed
}
// AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but
// with an extra opaque pointer, which we discard via this adapter.
const auto call_free = [](void* ptr, void*) { std::free(ptr); };
#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28
T* p = static_cast<T*>(std::aligned_alloc(Alignment(), bytes));
#else
void* mem = nullptr;
int err = posix_memalign(&mem, Alignment(), bytes);
HWY_ASSERT(err == 0);
T* p = static_cast<T*>(mem);
#endif
return hwy::AlignedFreeUniquePtr<T[]>(
p, hwy::AlignedFreer(call_free, nullptr));
}
private:
static size_t DetectPageSize();
// Required for BindMemory. Usually 4K, but can differ on Arm.
static size_t bytes_per_page_;
static bool use_numa_;
static size_t alignment_;
};
// Used in MatMul and allocator.h. Defined here because it depends on
// Allocator::Alignment().
template <class Func>
void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols,
size_t bytes_per_element, const Func& func) {
// Both rows and cols must be a multiple of the alignment to avoid
// touching remote pages.
const size_t multiple = Allocator::Alignment() / bytes_per_element;
// Static partitioning of columns across packages. We assume that column
// sharding is more expensive, hence we distribute columns across packages,
// of which there are usually only one or two. For MatMul, the final result is
// the sum of each package's partial dot products.
hwy::ThreadPool& all_packages = nested.AllPackages();
const size_t num_packages = all_packages.NumWorkers();
const size_t cols_per_package =
hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple);
const size_t col_tasks = hwy::DivCeil(cols, cols_per_package);
HWY_ASSERT(col_tasks <= num_packages);
all_packages.Run(
0, col_tasks, [&](uint64_t package_idx, size_t package_thread) {
HWY_ASSERT(package_idx == package_thread); // one task per worker
const size_t col_begin = package_idx * cols_per_package;
const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols);
// Static partitioning of rows across the package's clusters. We assume
// that row sharding is cheaper. In MatMul, results can indeed be
// computed independently for each row of B.
hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx);
const size_t num_clusters = all_clusters.NumWorkers();
const size_t rows_per_cluster =
hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple);
const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster);
HWY_ASSERT(row_tasks <= num_clusters);
all_clusters.Run(
0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) {
HWY_ASSERT(cluster_idx == cluster_thread); // one task per worker
// For binding to NUMA node.
const size_t node = nested.Node(package_idx, cluster_idx);
// Older CPUs that predate chiplets typically have only one
// cluster, so callers should also parallelize using this
// per-cluster pool.
hwy::ThreadPool& cluster =
nested.Cluster(package_idx, cluster_idx);
// This plus the worker from `cluster->Run` is the TLS index.
const size_t worker_offset =
nested.WorkerOffset(package_idx, cluster_idx);
const size_t row_begin = cluster_idx * rows_per_cluster;
const size_t row_end =
HWY_MIN(row_begin + rows_per_cluster, rows);
func(node, cluster, worker_offset, row_begin, row_end, col_begin,
col_end);
});
});
}
void BindTensor(NestedPools& nested, size_t rows, size_t cols,
size_t bytes_per_col, void* ptr);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_
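Finally, a sketch (editorial) of the callback shape StaticPartitionRowsAndCols expects, matching the parameter order above and its use in MatMulSlow:

  gcpp::StaticPartitionRowsAndCols(
      pools, rows, cols, sizeof(float),
      [&](size_t node, hwy::ThreadPool& cluster, size_t worker_offset,
          size_t row_begin, size_t row_end, size_t col_begin, size_t col_end) {
        cluster.Run(row_begin, row_end, [&](uint64_t row, size_t thread) {
          // Process columns [col_begin, col_end) of this row; worker_offset + thread
          // can index per-worker scratch such as env.Buf().
        });
      });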

View File

@ -194,13 +194,12 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
ModelInfo info_; ModelInfo info_;
}; };
static inline Gemma CreateGemma(const LoaderArgs& loader, static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
PerClusterPools& pools) {
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools); return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
} }
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader, static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
PerClusterPools& pools) { NestedPools& pools) {
return std::make_unique<Gemma>(loader.tokenizer, loader.weights, return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.Info(), pools); loader.Info(), pools);
} }

View File

@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_
@ -32,252 +30,45 @@
namespace gcpp { namespace gcpp {
// DEPRECATED, will be replaced by NestedPools once MatMul is updated.
// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an
// 'outer' thread pool with one worker per cluster.
//
// Useful for hierarchical parallelism, which makes sense when there are few
// but large tasks which should be parallelized by workers sharing a cache.
// This also implies lower latency for barrier synchronization of those workers.
class PerClusterPools {
using LPS = hwy::LogicalProcessorSet;
static inline std::vector<size_t> CoresInLPS(const LPS& cluster) {
std::vector<size_t> cores;
cores.reserve(cluster.Count());
cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); });
return cores;
}
using CoreBitSets = std::vector<LPS>;
// Returns empty if detection failed.
CoreBitSets DetectCoresPerCluster() {
CoreBitSets clusters;
if (!have_threading_support_) return clusters;
// Which processors are not disabled via OS, taskset, or numactl.
LPS enabled;
// If we don't know, better to abort rather than risk oversubscribing.
if (!GetThreadAffinity(enabled)) return clusters;
hwy::Topology topology;
if (topology.packages.empty()) return clusters;
// Merge all clusters into one set, as a stopgap to emulate gemma-inl's
// prior single pool.
// TODO: remove once MatMul supports hierarchical parallelism.
LPS all;
// For each cluster, add its enabled *cores*.
for (const hwy::Topology::Package& package : topology.packages) {
for (const hwy::Topology::Cluster& cluster : package.clusters) {
cluster.lps.Foreach([&](size_t lp) {
if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
all.Set(lp);
}
});
}
/* code to reinstate:
for (const hwy::Topology::Cluster& cluster : package.clusters) {
// Only use enabled *cores*, and only add if not empty.
cluster.lps.Foreach([&](size_t lp) {
if (enabled.Get(lp) && topology.lps[lp].smt == 0) {
all.Set(lp);
}
});
if (lps.Any()) clusters.push_back(lps);
}
*/
}
if (all.Any()) clusters.push_back(all);
// Sort by descending number of enabled cores, so that we preferentially
// use the largest clusters.
std::sort(clusters.begin(), clusters.end(),
[](const LPS& a, const LPS& b) { return a.Count() > b.Count(); });
return clusters;
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
outer_pool_.SetWaitMode(wait_mode);
for (auto& inner : inner_pools_) {
inner->SetWaitMode(wait_mode);
}
}
// `user_max_or_zero` == 0 means no limit, which is the case for the defaults
// of `AppArgs` `max_clusters` and `num_threads`.
static inline size_t CapIfNonZero(size_t num_workers,
size_t user_max_or_zero) {
return (user_max_or_zero == 0) ? num_workers
: HWY_MIN(num_workers, user_max_or_zero);
}
// Returns the number of threads for `ThreadPool` to create: zero if there is
// no threading support, otherwise the capped number of workers minus the
// caller of `ThreadPool::Run`, which is the outer worker or main thread.
size_t CappedNumThreads(size_t num_workers, size_t user_max_or_zero) const {
if (!have_threading_support_) return 0;
const size_t capped_num_workers =
CapIfNonZero(num_workers, user_max_or_zero);
// Avoid underflow if number of workers is zero.
return capped_num_workers == 0 ? 0 : capped_num_workers - 1;
}
// Returns the number of workers for the inner pool whose index is `outer`, or
// 0 to indicate no limit if `max_threads` is zero.
size_t MaxInnerWorkers(const size_t max_threads, const size_t outer_workers,
const size_t outer) const {
HWY_DASSERT(outer < outer_workers);
if (max_threads == 0) return 0; // no limit
// Round down so we do not exceed the max.
const size_t max_threads_per_outer = max_threads / outer_workers;
// First outer pool gets the remainder.
const size_t remainder = (outer == 0) ? (max_threads % outer_workers) : 0;
return 1 + max_threads_per_outer + remainder;
}
public:
// Move-only.
PerClusterPools() = delete;
PerClusterPools(const PerClusterPools&) = delete;
PerClusterPools& operator=(const PerClusterPools&) = delete;
PerClusterPools(PerClusterPools&&) = delete;
PerClusterPools& operator=(PerClusterPools&&) = delete;
// PerClusterPools supports spin waits (see StartSpinning below). To prevent
// drastic slowdowns caused by excessive user-specified thread counts, which
// result in threads not running on their own core, we only allow for
// *upper bounds* on the number of clusters and threads. The actual number of
// clusters and threads are still limited by the detected topology.
// `max_threads` is the upper bound on threads to distribute among clusters,
// not including the one outer thread per cluster.
//
// `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1)
: have_threading_support_(hwy::HaveThreadingSupport()),
cores_per_cluster_(DetectCoresPerCluster()),
outer_pool_(CappedNumThreads(cores_per_cluster_.size(), max_clusters)) {
// Topology detection failed - it currently requires Linux.
if (cores_per_cluster_.empty()) {
// Create a single inner pool with up to TotalLogicalProcessors() / 2
// workers, further limited by `max_threads` if nonzero, and then pin to
// the first N processors, which are typically on the first socket.
const size_t num_threads =
CappedNumThreads(hwy::TotalLogicalProcessors() / 2, max_threads);
if (pin == -1) pin = num_threads > 8;
fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n",
num_threads, pin);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
if (num_threads > 1 && pin) {
inner_pools_.back()->Run(0, num_threads,
[](uint64_t /*task*/, size_t thread) {
hwy::PinThreadToLogicalProcessor(thread);
});
}
return;
}
for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) {
const size_t max_inner_workers =
MaxInnerWorkers(max_threads, outer_pool_.NumWorkers(), outer);
const size_t num_threads = CappedNumThreads(
cores_per_cluster_[outer].Count(), max_inner_workers);
inner_pools_.push_back(std::make_unique<hwy::ThreadPool>(num_threads));
}
if (pin == -1) {
pin = (outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers()) >= 12;
}
if (pin) {
// For each inner pool, pin their threads AND the associated outer thread
// (the one calling inner.Run()) to the enabled cores in the cluster.
outer_pool_.Run(
0, outer_pool_.NumWorkers(),
[this](uint64_t outer, size_t outer_thread) {
HWY_ASSERT(outer == outer_thread); // each outer has one task
hwy::ThreadPool& inner = *inner_pools_[outer];
const std::vector<size_t> cores =
CoresInLPS(cores_per_cluster_[outer]);
// May have been capped by max_threads.
HWY_ASSERT(inner.NumWorkers() <= cores.size());
inner.Run(0, inner.NumWorkers(),
[&cores](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each inner has one task
hwy::PinThreadToLogicalProcessor(cores[task]);
});
});
}
}
// Spinning reduces the latency of barrier synchronization, but wastes lots of
// energy for long waits, so only do it during generation. This might also be
// unsafe in virtualized environments because we require threads to be running
// on their own core and thus responsive to the barrier synchronization.
void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); }
void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); }
// Bitset of cores, one per cluster, or empty if detection failed. Useful for
// displaying the topology.
const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; }
hwy::ThreadPool& Outer() { return outer_pool_; }
hwy::ThreadPool& Inner(size_t outer) {
HWY_ASSERT(outer < Outer().NumWorkers());
return *inner_pools_[outer];
}
// Returns number of logical processors, for allocating per-thread buffers.
size_t NumLP() const {
return outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers();
}
private:
bool have_threading_support_;
CoreBitSets cores_per_cluster_;
hwy::ThreadPool outer_pool_;
// hwy::ThreadPool is unfortunately not marked as movable, so we have to use
// unique_ptr.
std::vector<std::unique_ptr<hwy::ThreadPool>> inner_pools_;
};
// A slice of a 1D integer range such as the indices of packages or clusters. // A slice of a 1D integer range such as the indices of packages or clusters.
// This allows assigning them to multiple instances of our binary. // This allows assigning them to multiple instances of our binary.
struct BoundedSlice { class BoundedSlice {
public:
// Defaults to "use all detected". // Defaults to "use all detected".
BoundedSlice(size_t skip = 0, size_t max = 0) : skip(skip), max(max) {} BoundedSlice(size_t skip = 0, size_t max = 0) : skip_(skip), max_(max) {}
// How many to skip, or equivalently, index of the first to use. It is an size_t Begin() const { return skip_; }
// error if this is >= `detected`, because that would leave none for this
// instance to use.
size_t skip;
// Upper bound on the number to use, or zero if no limit.
size_t max;
// STL-style one past the end. // STL-style one past the end.
size_t End(size_t detected) const { size_t End(size_t detected) const {
return (max == 0) ? detected : HWY_MIN(detected, skip + max); return (max_ == 0) ? detected : HWY_MIN(detected, skip_ + max_);
} }
// Number of elements in the slice. // Number of elements in the slice.
size_t Num(size_t detected) const { return End(detected) - skip; } size_t Num(size_t detected) const { return End(detected) - Begin(); }
bool Contains(size_t detected, size_t idx) const {
return Begin() <= idx && idx < End(detected);
}
template <class Func> template <class Func>
void ForEach(const char* name, size_t detected, const Func& func) { void Foreach(const char* name, size_t detected, const Func& func) {
if (skip >= detected) { if (Begin() >= detected) {
HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip, name, detected); HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip_, name, detected);
} }
for (size_t i = skip; i < End(detected); ++i) { for (size_t i = Begin(); i < End(detected); ++i) {
func(i); func(i);
} }
} }
private:
// How many to skip, or equivalently, index of the first to use. It is an
// error if this is >= `detected`, because that would leave none for this
// instance to use.
size_t skip_;
// Upper bound on the number to use, or zero if no limit.
size_t max_;
}; };
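To make the skip/max semantics concrete, a small hedged sketch (values hypothetical) using only the members declared above:

#include "util/threading.h"
#include "hwy/base.h"  // HWY_ASSERT

namespace gcpp {

// Sketch: skip the first of four detected clusters and use at most two more.
void ExampleBoundedSlice() {
  BoundedSlice slice(/*skip=*/1, /*max=*/2);
  // Begin()=1, End(4)=3 (one past the last used index), Num(4)=2.
  HWY_ASSERT(slice.Num(/*detected=*/4) == 2);
  HWY_ASSERT(slice.Contains(4, 1) && slice.Contains(4, 2));
  HWY_ASSERT(!slice.Contains(4, 0) && !slice.Contains(4, 3));
  slice.Foreach("cluster", /*detected=*/4,
                [](size_t idx) { (void)idx; /* called with idx = 1, then 2 */ });
}

}  // namespace gcpp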
// "LP" is a logical processor, a 0-based index passed to the OS. // "LP" is a logical processor, a 0-based index passed to the OS.
@ -288,201 +79,255 @@ using LPS = hwy::LogicalProcessorSet;
// NOTE: if topology is unknown or the OS affinity is too restrictive, we fall // NOTE: if topology is unknown or the OS affinity is too restrictive, we fall
// back to a single package and cluster. // back to a single package and cluster.
class BoundedTopology { class BoundedTopology {
// Sort packages/clusters by descending size so that users who only use one
// get the largest.
template <class Group>
static void SortByDescendingLPs(std::vector<Group>& groups) {
std::sort(groups.begin(), groups.end(), [](const Group& a, const Group& b) {
return a.num_lps > b.num_lps;
});
}
public: public:
struct Cluster {
// Simple version when topology is unknown.
explicit Cluster(size_t num_workers) : num_lps(num_workers) {
HWY_ASSERT(num_lps != 0);
}
Cluster(const std::vector<hwy::Topology::LP>& all_lps, const LPS& enabled,
size_t package_lp, const hwy::Topology::Cluster& cluster,
LPS& package_lps) {
// All first-hyperthread LPs from the cluster that are enabled and not
// already in use as the package representative.
cluster.lps.Foreach([&](size_t lp) {
if (all_lps[lp].smt == 0 && enabled.Get(lp) && lp != package_lp) {
HWY_ASSERT(!lps.Get(lp));
lps.Set(lp);
HWY_ASSERT(!package_lps.Get(lp));
package_lps.Set(lp);
}
});
num_lps = lps.Count(); // = 0 if all disabled.
}
LPS lps;
size_t num_lps;
// Set by caller to the first of `lps` if there are multiple clusters in a
// package.
size_t cluster_lp = 0;
};
struct Package {
// Simple version when topology is unknown.
explicit Package(size_t num_workers) {
package_lp = 0;
num_lps = num_workers;
clusters.push_back(Cluster(num_workers));
}
Package(size_t package_idx, const hwy::Topology& topology,
const LPS& enabled, BoundedSlice cluster_slice) {
const hwy::Topology::Package& package = topology.packages[package_idx];
package_lp = package.clusters[0].lps.First();
cluster_slice.ForEach(
"cluster", package.clusters.size(), [&](size_t cluster_idx) {
Cluster cluster(topology.lps, enabled, package_lp,
package.clusters[cluster_idx], lps);
if (HWY_LIKELY(cluster.num_lps != 0)) {
num_lps += cluster.num_lps; // before std::move
clusters.push_back(std::move(cluster));
}
});
// Note that it is possible for `clusters` to be empty if its LPs are all
// disabled. If so, the caller will ignore topology and create a single
// package and cluster.
SortByDescendingLPs(clusters);
// If there are multiple clusters, set their first LP to represent the
// cluster and mark them as unavailable for its pool.
if (clusters.size() > 1) {
for (Cluster& cluster : clusters) {
cluster.cluster_lp = cluster.lps.First();
// Nonzero because if lp == 0 were enabled, it would be used as
// `package_lp` and excluded from `cluster.lps`.
HWY_ASSERT(cluster.cluster_lp != 0);
HWY_ASSERT(cluster.cluster_lp != package_lp);
cluster.lps.Clear(cluster.cluster_lp);
}
}
}
size_t package_lp;
LPS lps;
size_t num_lps = 0;
std::vector<Cluster> clusters;
};
BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice,
BoundedSlice lp_slice) { BoundedSlice lp_slice) {
const bool have_threading_support = hwy::HaveThreadingSupport(); // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl.
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl. LPS enabled_lps;
bool missing_cluster = false; if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) {
const size_t num_lps = hwy::TotalLogicalProcessors();
fprintf(
stderr,
"Warning, unknown OS affinity, considering all %zu LPs enabled\n.",
num_lps);
for (size_t lp = 0; lp < hwy::TotalLogicalProcessors(); ++lp) {
enabled_lps.Set(lp);
}
}
if (HWY_LIKELY(have_threading_support && !topology_.packages.empty())) { // Without threading support, only keep the first enabled LP; it might still
(void)GetThreadAffinity(enabled_lps); // failure = all disabled // make sense to pin the main thread.
if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) {
HWY_ASSERT(enabled_lps.Any());
const size_t lp = enabled_lps.First();
enabled_lps = LPS();
enabled_lps.Set(lp);
}
// No effect if topology is unknown or `enabled_lps` is empty. if (HWY_LIKELY(!topology_.packages.empty())) {
package_slice.ForEach( InitFromTopology(enabled_lps, package_slice, cluster_slice);
"package", topology_.packages.size(), [&](size_t package_idx) { }
Package package(package_idx, topology_, enabled_lps, cluster_slice);
// Skip if empty - can happen due to `enabled_lps`. // Topology unknown or no packages with enabled LPs: create a single
if (HWY_LIKELY(!package.clusters.empty())) { // package with one cluster, and one node.
total_lps_ += package.num_lps; // before std::move if (HWY_UNLIKELY(NumPackages() == 0)) {
packages_.push_back(std::move(package)); InitFromSlice(enabled_lps, lp_slice);
}
HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0);
}
size_t NumPackages() const { return packages_.size(); }
const char* TopologyString() const { return topology_string_; }
size_t NumNodes() const { return nodes_.Count(); }
class Cluster {
public:
// Topology is unknown, rely on OS affinity and user-specified slice.
Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) {
// Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so
// we honor both the OS affinity and the user-specified slice. Note that
// this can be used to exclude hyperthreads because Linux groups LPs by
// sibling index; the first `num_cores` LPs are one per core, not siblings.
const size_t detected = enabled_lps.Count();
size_t enabled_idx = 0;
enabled_lps.Foreach([&](size_t lp) {
if (lp_slice.Contains(detected, enabled_idx++)) {
AddLP(lp);
} }
}); });
for (Package& package : packages_) { // lp_slice can only reduce the number of `enabled_lps`, and not below 1.
missing_cluster = package.clusters.empty(); HWY_ASSERT(num_workers_ != 0);
if (HWY_UNLIKELY(missing_cluster)) {
fprintf(
stderr,
"Warning, found no clusters for package with %zu LPs.\nWe will "
"ignore topology and assume a single package/cluster.\n",
package.num_lps);
break;
}
}
} }
// Topology unknown or any package ended up empty: create a single package Cluster(const LPS& enabled_lps,
// with one cluster. const std::vector<hwy::Topology::LP>& all_lps,
if (HWY_UNLIKELY(packages_.empty() || missing_cluster)) { const hwy::Topology::Cluster& tcluster) {
// We do not bother to detect hyperthreads. Not all CPUs have two per bool is_first_lp = true;
// core, so instead of dividing, rely on the user's `lp_slice.max`. This
// works because Linux groups LPs by HT. tcluster.lps.Foreach([&](size_t lp) {
const size_t num_lps = have_threading_support // Skip if not first-hyperthread or disabled.
? lp_slice.Num(hwy::TotalLogicalProcessors()) if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return;
: 1;
packages_.clear(); AddLP(lp);
packages_.push_back(Package(num_lps));
total_lps_ = num_lps; // Set `node` once, and ensure subsequent nodes match - we assume there
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", num_lps); // is only one NUMA node per cluster.
const size_t lp_node = static_cast<size_t>(all_lps[lp].node);
if (is_first_lp) {
is_first_lp = false;
node_ = lp_node;
} else { } else {
SortByDescendingLPs(packages_); static bool warned = false;
if (lp_node != node_ && !warned) {
const hwy::Topology::Package& tpackage0 = topology_.packages[0]; warned = true;
HWY_ASSERT(!tpackage0.clusters.empty()); fprintf(stderr,
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0]; "WARNING: lp %zu on node %zu != cluster node %zu.\n", lp,
const Package& package0 = GetPackage(0); lp_node, node_);
const Cluster& cluster0 = GetCluster(0, 0); }
snprintf(topology_string_, sizeof(topology_string_), }
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(), });
tpackage0.clusters.size(), tcluster0.lps.Count(),
packages_.size(), package0.clusters.size(), cluster0.num_lps);
} }
HWY_ASSERT(NumPackages() != 0); // For SortByDescendingSize.
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) { size_t Size() const { return num_workers_; }
HWY_ASSERT(NumClusters(package_idx) != 0);
} // Returns vector with all enabled LPs, used for pinning.
std::vector<size_t> LPVector() const {
std::vector<size_t> lps;
lps.reserve(lps_.Count());
lps_.Foreach([&lps](size_t lp) { lps.push_back(lp); });
return lps;
} }
const char* TopologyString() const { return topology_string_; } size_t Node() const { return node_; }
size_t NumPackages() const { return packages_.size(); } private:
const Package& GetPackage(size_t package_idx) const { void AddLP(size_t lp) {
HWY_ASSERT(package_idx < NumPackages()); HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness
return packages_[package_idx]; lps_.Set(lp);
} ++num_workers_;
Package& GetPackage(size_t package_idx) {
HWY_ASSERT(package_idx < NumPackages());
return packages_[package_idx];
} }
// Enabled LPs; if topology is known, only the ones in this cluster.
LPS lps_;
// How many workers in the per-cluster pool. If 0, this Cluster is removed.
size_t num_workers_ = 0;
// NUMA node, set from hwy::Topology::LP::node.
size_t node_ = 0;
}; // Cluster
size_t NumClusters(size_t package_idx) const { size_t NumClusters(size_t package_idx) const {
return GetPackage(package_idx).clusters.size(); HWY_ASSERT(package_idx < NumPackages());
return packages_[package_idx].clusters.size();
} }
const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const { const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const {
const Package& package = GetPackage(package_idx); HWY_ASSERT(package_idx < NumPackages());
const Package& package = packages_[package_idx];
HWY_ASSERT(cluster_idx < package.clusters.size()); HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx]; return package.clusters[cluster_idx];
} }
Cluster& GetCluster(size_t package_idx, size_t cluster_idx) { Cluster& GetCluster(size_t package_idx, size_t cluster_idx) {
Package& package = GetPackage(package_idx); HWY_ASSERT(package_idx < NumPackages());
Package& package = packages_[package_idx];
HWY_ASSERT(cluster_idx < package.clusters.size()); HWY_ASSERT(cluster_idx < package.clusters.size());
return package.clusters[cluster_idx]; return package.clusters[cluster_idx];
} }
// Returns number of logical processors, for allocating per-thread buffers. // Returns total number of cluster workers, for deciding whether to pin.
size_t NumLP() const { return total_lps_; } size_t TotalWorkers() const {
size_t total_workers = 0;
for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) {
const size_t num_clusters = NumClusters(package_idx);
for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) {
total_workers += GetCluster(package_idx, cluster_idx).Size();
}
}
return total_workers;
}
private: private:
// Sort T := packages/clusters by descending Size() so that users who only use
// one of them get the largest.
template <class T>
static void SortByDescendingSize(std::vector<T>& groups) {
std::sort(groups.begin(), groups.end(),
[](const T& a, const T& b) { return a.Size() > b.Size(); });
}
struct Package {
// Topology is unknown, rely on OS affinity and user-specified slice.
Package(const LPS& enabled_lps, BoundedSlice lp_slice) {
clusters.push_back(Cluster(enabled_lps, lp_slice));
}
// NOTE: caller is responsible for checking whether `clusters` is empty.
Package(const LPS& enabled_lps, const hwy::Topology& topology,
size_t package_idx, BoundedSlice cluster_slice) {
const hwy::Topology::Package& tpackage = topology.packages[package_idx];
// Populate `clusters` with the subset of clusters in `cluster_slice` that
// have any enabled LPs. If `clusters` remains empty, the caller will
// skip this `Package`.
clusters.reserve(cluster_slice.Num(tpackage.clusters.size()));
cluster_slice.Foreach(
"cluster", tpackage.clusters.size(), [&](size_t cluster_idx) {
const hwy::Topology::Cluster& tcluster =
tpackage.clusters[cluster_idx];
Cluster cluster(enabled_lps, topology.lps, tcluster);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(cluster.Size() != 0)) {
clusters.push_back(std::move(cluster));
}
});
SortByDescendingSize(clusters);
}
// For SortByDescendingSize.
size_t Size() const { return clusters.size(); }
std::vector<Cluster> clusters;
}; // Package
// Main part of ctor, called when topology is known.
void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice,
BoundedSlice cluster_slice) {
// (Possibly empty) subset of `Topology` packages that have `enabled_lps`.
package_slice.Foreach(
"package", topology_.packages.size(), [&](size_t package_idx) {
Package package(enabled_lps, topology_, package_idx, cluster_slice);
// Skip if empty, i.e. too few `enabled_lps`.
if (HWY_LIKELY(!package.clusters.empty())) {
packages_.push_back(std::move(package));
}
});
if (NumPackages() == 0) return;
SortByDescendingSize(packages_);
const hwy::Topology::Package& tpackage0 = topology_.packages[0];
HWY_ASSERT(!tpackage0.clusters.empty());
const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0];
// GetCluster(0, 0) is valid because only non-empty Packages were kept.
snprintf(topology_string_, sizeof(topology_string_),
"%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(),
tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(),
NumClusters(0), GetCluster(0, 0).Size());
// Remember NUMA nodes of *enabled* LPs.
enabled_lps.Foreach([&](size_t lp) {
nodes_.Set(static_cast<size_t>(topology_.lps[lp].node));
});
}
void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) {
packages_.push_back(Package(enabled_lps, lp_slice));
snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu",
GetCluster(0, 0).Size());
// Assume a single NUMA node.
nodes_.Set(0);
HWY_ASSERT(NumNodes() == 1);
}
hwy::Topology topology_; hwy::Topology topology_;
size_t total_lps_ = 0;
std::vector<Package> packages_; std::vector<Package> packages_;
char topology_string_[96]; char topology_string_[96];
LPS nodes_;
}; };
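As a hedged illustration of the accessors above (not part of the diff), printing the bounded topology might look as follows; default `BoundedSlice()` arguments mean "use all detected" packages, clusters, and LPs.

#include <stdio.h>
#include "util/threading.h"

namespace gcpp {

// Sketch: print each cluster's worker count and NUMA node.
void ExamplePrintTopology() {
  BoundedTopology topology(BoundedSlice(), BoundedSlice(), BoundedSlice());
  fprintf(stderr, "%s, %zu NUMA node(s)\n", topology.TopologyString(),
          topology.NumNodes());
  for (size_t p = 0; p < topology.NumPackages(); ++p) {
    for (size_t c = 0; c < topology.NumClusters(p); ++c) {
      const BoundedTopology::Cluster& cluster = topology.GetCluster(p, c);
      fprintf(stderr, "  package %zu cluster %zu: %zu workers, node %zu\n", p,
              c, cluster.Size(), cluster.Node());
    }
  }
}

}  // namespace gcpp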
// Creates a hierarchy of thread pools according to BoundedTopology: one with a // Creates a hierarchy of thread pools according to `BoundedTopology`: one with
// thread per enabled package; for each of those, one with a thread per enabled // a thread per enabled package; for each of those, one with a thread per
// cluster (CCX/shared L3), and for each of those, the remaining enabled cores // enabled cluster (CCX/shared L3), and for each of those, the remaining
// in that cluster. The cores representing each package and cluster are not // enabled cores in that cluster.
// included in the per-cluster pool because we support spin-waiting, hence //
// there should be at most one thread per HW core. // Because we support spin waits, each thread must stay responsive, so we
// create at most one thread per enabled core.
// For example, when there are two packages with four clusters of 8 cores,
// `AllPackages` has the main thread plus one extra thread, each `AllClusters`
// has one of the `AllPackages` threads plus three extras, each `Cluster` runs
// on one `AllClusters` thread plus seven extra workers, for a total of
// 1 + 2*3 + 2*(4*7) = 63 extras plus the main thread.
// //
// Useful when there are tasks which should be parallelized by workers sharing a // Useful when there are tasks which should be parallelized by workers sharing a
// cache, or on the same NUMA node. In both cases, individual pools have lower // cache, or on the same NUMA node. In both cases, individual pools have lower
@ -498,14 +343,13 @@ class NestedPools {
NestedPools& operator=(NestedPools&&) = delete; NestedPools& operator=(NestedPools&&) = delete;
// `max_threads` is the maximum number of threads to divide among all // `max_threads` is the maximum number of threads to divide among all
// clusters. It does not include the package and cluster representatives. // clusters. This is more intuitive than a per-cluster limit for users who
// This is more intuitive than a per-cluster limit for users who may not be // may not be aware of the CPU topology.
// aware of the CPU topology.
// //
// To ensure we do not create more threads than there are HW cores, which // To ensure we do not create more threads than there are HW cores, which
// would cause huge slowdowns when spinning, `BoundedSlice` imposes upper // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// bounds on the number of detected packages and clusters rather than // only impose upper bounds on the number of detected packages and clusters
// defining an exact amount. // rather than defining the actual number of threads.
// //
// `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically. // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically.
NestedPools(size_t max_threads, int pin = -1, NestedPools(size_t max_threads, int pin = -1,
@ -513,12 +357,14 @@ class NestedPools {
BoundedSlice cluster_slice = BoundedSlice(), BoundedSlice cluster_slice = BoundedSlice(),
BoundedSlice lp_slice = BoundedSlice()) BoundedSlice lp_slice = BoundedSlice())
: topology_(package_slice, cluster_slice, lp_slice) { : topology_(package_slice, cluster_slice, lp_slice) {
if (pin == -1) pin = topology_.NumLP() >= 12; if (pin == -1) pin = topology_.TotalWorkers() >= 12;
packages_.resize(topology_.NumPackages()); packages_.resize(topology_.NumPackages());
all_packages_ = MakePool(packages_.size()); all_packages_ = MakePool(packages_.size());
const size_t max_workers_per_package = max_threads / packages_.size(); const size_t max_workers_per_package = max_threads / packages_.size();
// Parallel to ensure we also pin the calling (main) thread. // Each worker in all_packages_, including the main thread, will be the
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_->Run( all_packages_->Run(
0, all_packages_->NumWorkers(), 0, all_packages_->NumWorkers(),
[&](uint64_t package_idx, size_t thread) { [&](uint64_t package_idx, size_t thread) {
@ -526,10 +372,24 @@ class NestedPools {
packages_[package_idx] = Package( packages_[package_idx] = Package(
topology_, package_idx, max_workers_per_package, pin, lp_slice); topology_, package_idx, max_workers_per_package, pin, lp_slice);
}); });
// For mapping (package, cluster, thread) to TLS indices, which may be
// noncontiguous because cluster/thread counts can differ across packages.
HWY_ASSERT(!packages_.empty() && packages_.size() <= 16);
for (const Package& p : packages_) {
max_clusters_per_package_ =
HWY_MAX(max_clusters_per_package_, p.NumClusters());
max_workers_per_cluster_ =
HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster());
}
HWY_ASSERT(max_clusters_per_package_ >= 1);
HWY_ASSERT(max_clusters_per_package_ <= 64);
HWY_ASSERT(max_workers_per_cluster_ >= 1);
HWY_ASSERT(max_workers_per_cluster_ <= 256);
} }
// Spinning reduces the latency of barrier synchronization, but wastes lots // Spinning reduces the latency of barrier synchronization, but wastes lots
// of energy for long waits, so only do it during generation. This might // of energy for long waits, so only do it during generation. Spinning might
// also be unsafe in virtualized environments because we require threads to // also be unsafe in virtualized environments because we require threads to
// be running on their own core and thus responsive to the barrier // be running on their own core and thus responsive to the barrier
// synchronization. // synchronization.
@ -538,18 +398,45 @@ class NestedPools {
hwy::ThreadPool& AllPackages() { return *all_packages_; } hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t package_idx) { hwy::ThreadPool& AllClusters(size_t package_idx) {
HWY_ASSERT(package_idx < AllPackages().NumWorkers()); HWY_DASSERT(package_idx < packages_.size());
return *packages_[package_idx].all_clusters; return packages_[package_idx].AllClusters();
} }
hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) { hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) {
HWY_ASSERT(cluster_idx < AllClusters(package_idx).NumWorkers()); HWY_DASSERT(package_idx < packages_.size());
return *packages_[package_idx].clusters[cluster_idx]; return packages_[package_idx].Cluster(cluster_idx);
} }
// For binding to NUMA nodes.
size_t Node(size_t package_idx, size_t cluster_idx) const {
return topology_.GetCluster(package_idx, cluster_idx).Node();
}
// Reasonably tight upper bound for allocating thread-local storage (TLS).
size_t MaxWorkers() const {
return packages_.size() * max_clusters_per_package_ *
max_workers_per_cluster_;
}
// Returns the first of `cluster.NumWorkers()` TLS indices, to which callers
// add the worker index given by `cluster.Run`.
size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const {
return (package_idx * max_clusters_per_package_ + cluster_idx) *
max_workers_per_cluster_;
}
// For Allocator
const BoundedTopology& Topology() const { return topology_; }
const char* TopologyString() const { return topology_.TopologyString(); } const char* TopologyString() const { return topology_.TopologyString(); }
// Returns number of logical processors, for allocating per-thread buffers. // Returns a single pool on the first package: either one thread per cluster
size_t NumLP() const { return topology_.NumLP(); } // if there is more than one, which maximizes available memory bandwidth, or
// the first cluster, which is typically the whole package. For use by callers
// that only parallelize over a 1D range, as opposed to the nested
// parallelism of `StaticPartitionRowsAndCols`.
hwy::ThreadPool& Pool() {
// Only one cluster: use its pool, typically a whole socket.
if (AllClusters(0).NumWorkers() == 1) return Cluster(0, 0);
return AllClusters(0);
}
private: private:
// `max_or_zero` == 0 means no limit. // `max_or_zero` == 0 means no limit.
@ -569,69 +456,72 @@ class NestedPools {
} }
class Package { class Package {
static PoolPtr CreateClusterPool(const BoundedTopology::Cluster& cluster,
size_t max_cluster_workers, int pin,
BoundedSlice lp_slice) {
PoolPtr pool =
MakePool(CapIfNonZero(cluster.num_lps, max_cluster_workers));
if (!pin) return pool;
// Else: pin all new threads AND the calling thread from `all_clusters`.
// We know the topology: pin to this cluster's cores, including the
// calling thread from `all_clusters`.
if (cluster.lps.Any()) {
std::vector<size_t> lps;
lps.reserve(cluster.num_lps);
cluster.lps.Foreach([&lps](size_t lp) { lps.push_back(lp); });
pool->Run(0, pool->NumWorkers(), [&lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lps[task]);
});
} else {
// Pin to consecutive LPs.
pool->Run(0, pool->NumWorkers(),
[lp_slice](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lp_slice.skip + thread);
});
}
return pool;
}
public: public:
Package() = default; // for vector Package() = default; // for vector
Package(const BoundedTopology& topology, size_t package_idx, Package(const BoundedTopology& topology, size_t package_idx,
size_t max_workers_per_package, int pin, BoundedSlice lp_slice) { size_t max_workers_per_package, int pin, BoundedSlice lp_slice) {
clusters.resize(topology.NumClusters(package_idx)); // Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(package_idx));
const size_t max_workers_per_cluster = const size_t max_workers_per_cluster =
max_workers_per_package / clusters.size(); max_workers_per_package / clusters_.size();
all_clusters = MakePool(clusters.size()); all_clusters_ = MakePool(clusters_.size());
// Parallel so we also pin the calling thread from `all_packages_`. // Parallel so we also pin the calling worker in `all_clusters` to
all_clusters->Run( // `cluster.lps`.
0, all_clusters->NumWorkers(), all_clusters_->Run(
0, all_clusters_->NumWorkers(),
[&](size_t cluster_idx, size_t thread) { [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster = const BoundedTopology::Cluster& cluster =
topology.GetCluster(package_idx, cluster_idx); topology.GetCluster(package_idx, cluster_idx);
clusters[cluster_idx] = CreateClusterPool( clusters_[cluster_idx] =
cluster, max_workers_per_cluster, pin, lp_slice); MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster));
if (HWY_LIKELY(pin)) {
// Pin threads AND the calling thread from `all_clusters` to lps.
const std::vector<size_t> lps = cluster.LPVector();
HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size());
clusters_[cluster_idx]->Run(
0, clusters_[cluster_idx]->NumWorkers(),
[&lps](uint64_t task, size_t thread) {
HWY_ASSERT(task == thread); // each worker has one task
hwy::PinThreadToLogicalProcessor(lps[task]);
});
}
}); });
} }
std::vector<PoolPtr> clusters; size_t NumClusters() const { return clusters_.size(); }
PoolPtr all_clusters; size_t MaxWorkersPerCluster() const {
}; size_t max_workers_per_cluster = 0;
for (const PoolPtr& cluster : clusters_) {
max_workers_per_cluster =
HWY_MAX(max_workers_per_cluster, cluster->NumWorkers());
}
return max_workers_per_cluster;
}
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
HWY_DASSERT(cluster_idx < clusters_.size());
return *clusters_[cluster_idx];
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_clusters_->SetWaitMode(wait_mode);
for (PoolPtr& cluster : clusters_) {
cluster->SetWaitMode(wait_mode);
}
}
private:
std::vector<PoolPtr> clusters_;
PoolPtr all_clusters_;
}; // Package
void SetWaitMode(hwy::PoolWaitMode wait_mode) { void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_packages_->SetWaitMode(wait_mode); all_packages_->SetWaitMode(wait_mode);
for (Package& package : packages_) { for (Package& package : packages_) {
package.all_clusters->SetWaitMode(wait_mode); package.SetWaitMode(wait_mode);
for (PoolPtr& cluster : package.clusters) {
cluster->SetWaitMode(wait_mode);
}
} }
} }
@ -639,12 +529,11 @@ class NestedPools {
std::vector<Package> packages_; std::vector<Package> packages_;
PoolPtr all_packages_; PoolPtr all_packages_;
};
static inline NestedPools CreateSinglePool(size_t max_threads, int pin = -1) { // For TLS indices.
const BoundedSlice one(0, 1); size_t max_clusters_per_package_ = 0;
return NestedPools(max_threads, pin, one, one); size_t max_workers_per_cluster_ = 0;
} };
} // namespace gcpp } // namespace gcpp
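For callers that only need flat 1D parallelism, a short hedged sketch of `Pool()` together with spin control (not part of the diff); the per-task work is a placeholder.

#include <stdio.h>
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

// Sketch: flat parallel-for over the first package, spinning only while busy.
void ExampleFlatParallelFor(NestedPools& pools, size_t num_tasks) {
  fprintf(stderr, "Topology: %s\n", pools.TopologyString());
  pools.StartSpinning();  // low-latency barriers while busy
  hwy::ThreadPool& pool = pools.Pool();
  pool.Run(0, num_tasks, [&](uint64_t task, size_t thread) {
    // Placeholder for real work; `thread` is in [0, pool.NumWorkers()).
    (void)task;
    (void)thread;
  });
  pools.StopSpinning();  // block while idle to save energy
}

}  // namespace gcpp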

View File

@ -35,27 +35,41 @@ TEST(ThreadingTest, TestBoundedSlice) {
{ {
BoundedSlice slice; BoundedSlice slice;
std::vector<size_t> expected; std::vector<size_t> expected;
slice.ForEach(name, 10, [&](size_t i) { expected.push_back(i); }); const size_t detected = 10;
EXPECT_EQ(10, slice.Num(10)); slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(10, slice.Num(detected));
EXPECT_THAT(expected, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); EXPECT_THAT(expected, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
EXPECT_TRUE(slice.Contains(detected, 0));
EXPECT_TRUE(slice.Contains(detected, 9));
EXPECT_FALSE(slice.Contains(detected, 10));
} }
// One arg: skip first N // One arg: skip first N
{ {
BoundedSlice slice(3); BoundedSlice slice(3);
std::vector<size_t> expected; std::vector<size_t> expected;
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); }); const size_t detected = 9;
EXPECT_EQ(6, slice.Num(9)); slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(6, slice.Num(detected));
EXPECT_THAT(expected, ElementsAre(3, 4, 5, 6, 7, 8)); EXPECT_THAT(expected, ElementsAre(3, 4, 5, 6, 7, 8));
EXPECT_FALSE(slice.Contains(detected, 2));
EXPECT_TRUE(slice.Contains(detected, 3));
EXPECT_TRUE(slice.Contains(detected, 8));
EXPECT_FALSE(slice.Contains(detected, 9));
} }
// Both args: skip first N, then use at most M // Both args: skip first N, then use at most M
{ {
BoundedSlice slice(3, 2); BoundedSlice slice(3, 2);
std::vector<size_t> expected; std::vector<size_t> expected;
slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); }); const size_t detected = 9;
EXPECT_EQ(2, slice.Num(9)); slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(2, slice.Num(detected));
EXPECT_THAT(expected, ElementsAre(3, 4)); EXPECT_THAT(expected, ElementsAre(3, 4));
EXPECT_FALSE(slice.Contains(detected, 2));
EXPECT_TRUE(slice.Contains(detected, 3));
EXPECT_TRUE(slice.Contains(detected, 4));
EXPECT_FALSE(slice.Contains(detected, 5));
} }
// Both args, but `max > detected - skip`: fewer than limit. Note that // Both args, but `max > detected - skip`: fewer than limit. Note that
@ -63,9 +77,13 @@ TEST(ThreadingTest, TestBoundedSlice) {
{ {
BoundedSlice slice(3, 2); BoundedSlice slice(3, 2);
std::vector<size_t> expected; std::vector<size_t> expected;
slice.ForEach(name, 4, [&](size_t i) { expected.push_back(i); }); const size_t detected = 4;
EXPECT_EQ(1, slice.Num(4)); slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); });
EXPECT_EQ(1, slice.Num(detected));
EXPECT_THAT(expected, ElementsAre(3)); EXPECT_THAT(expected, ElementsAre(3));
EXPECT_FALSE(slice.Contains(detected, 2));
EXPECT_TRUE(slice.Contains(detected, 3));
EXPECT_FALSE(slice.Contains(detected, 4));
} }
} }
@ -76,8 +94,6 @@ TEST(ThreadingTest, TestBoundedTopology) {
{ {
BoundedTopology topology(all, all, all); BoundedTopology topology(all, all, all);
fprintf(stderr, "%s\n", topology.TopologyString()); fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_NE(0, topology.NumPackages());
ASSERT_NE(0, topology.NumClusters(0));
} }
// Max one package // Max one package
@ -85,14 +101,12 @@ TEST(ThreadingTest, TestBoundedTopology) {
BoundedTopology topology(one, all, all); BoundedTopology topology(one, all, all);
fprintf(stderr, "%s\n", topology.TopologyString()); fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_EQ(1, topology.NumPackages()); ASSERT_EQ(1, topology.NumPackages());
ASSERT_NE(0, topology.NumClusters(0));
} }
// Max one cluster // Max one cluster
{ {
BoundedTopology topology(all, one, all); BoundedTopology topology(all, one, all);
fprintf(stderr, "%s\n", topology.TopologyString()); fprintf(stderr, "%s\n", topology.TopologyString());
ASSERT_NE(0, topology.NumPackages());
ASSERT_EQ(1, topology.NumClusters(0)); ASSERT_EQ(1, topology.NumClusters(0));
} }
} }