From 02ce1e344f115da0a3371a1f3efb7396413ded4a Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 18 Oct 2024 08:10:44 -0700 Subject: [PATCH] Use NestedPools, add NUMA infra Improved threading.h, fix thread counts for single package/cluster systems Temporarily forces to a single socket. Prefill 29.28 tps, decode 6.92. Also fix benchmarks.cc build, update tensor allocator to Allocator PiperOrigin-RevId: 687307167 --- BUILD.bazel | 37 +- CMakeLists.txt | 1 + backprop/optimize_test.cc | 4 +- compression/BUILD.bazel | 1 + compression/compress-inl.h | 1 - compression/compress.h | 131 ++---- evals/benchmark_helper.cc | 18 +- evals/benchmark_helper.h | 4 +- examples/hello_world/run.cc | 2 +- gemma/activations.h | 2 +- gemma/gemma-inl.h | 14 +- gemma/gemma.cc | 22 +- gemma/gemma.h | 7 +- gemma/run.cc | 5 +- gemma/weights.cc | 5 +- ops/dot-inl.h | 11 - ops/matmul-inl.h | 2 +- ops/matmul.h | 26 +- ops/matmul_test.cc | 48 ++- ops/matvec-inl.h | 9 + util/allocator.cc | 183 ++++++++ util/allocator.h | 153 ++++++- util/app.h | 5 +- util/threading.h | 813 ++++++++++++++++-------------------- util/threading_test.cc | 38 +- 25 files changed, 864 insertions(+), 678 deletions(-) create mode 100644 util/allocator.cc diff --git a/BUILD.bazel b/BUILD.bazel index c480f23..9172030 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -29,10 +29,24 @@ cc_library( ) cc_library( - name = "allocator", - hdrs = ["util/allocator.h"], + name = "threading", + hdrs = ["util/threading.h"], deps = [ "@highway//:hwy", + "@highway//:thread_pool", + "@highway//:topology", + ], +) + +cc_library( + name = "allocator", + srcs = ["util/allocator.cc"], + hdrs = ["util/allocator.h"], + deps = [ + ":basics", + ":threading", + "@highway//:hwy", + "@highway//:thread_pool", ], ) @@ -46,16 +60,6 @@ cc_library( ], ) -cc_library( - name = "threading", - hdrs = ["util/threading.h"], - deps = [ - "@highway//:hwy", - "@highway//:thread_pool", - "@highway//:topology", - ], -) - cc_test( name = "threading_test", srcs = ["util/threading_test.cc"], @@ -168,6 +172,7 @@ cc_test( # for test_suite. 
tags = ["hwy_ops_test"], deps = [ + ":allocator", ":ops", ":threading", "@googletest//:gtest_main", # buildcleaner: keep @@ -211,6 +216,7 @@ cc_library( hdrs = ["gemma/weights.h"], deps = [ ":common", + "//compression:blob_store", "//compression:compress", "//compression:io", "@highway//:hwy", @@ -393,6 +399,12 @@ cc_binary( ], ) +cc_library( + name = "benchmark_prompts", + hdrs = ["evals/prompts.h"], + deps = ["@highway//:hwy"], +) + cc_binary( name = "benchmarks", srcs = [ @@ -401,6 +413,7 @@ cc_binary( ], deps = [ ":benchmark_helper", + ":benchmark_prompts", "@google_benchmark//:benchmark", "@highway//:hwy", # base.h ], diff --git a/CMakeLists.txt b/CMakeLists.txt index bade5de..6b0b832 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,7 @@ set(SOURCES ops/sum-inl.h paligemma/image.cc paligemma/image.h + util/allocator.cc util/allocator.h util/app.h util/args.h diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index b47a48d..494f3b2 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -39,8 +39,8 @@ namespace gcpp { TEST(OptimizeTest, GradientDescent) { - PerClusterPools pools(1, 1); - hwy::ThreadPool& pool = pools.Inner(0); + NestedPools pools(1); + hwy::ThreadPool& pool = pools.Pool(); std::mt19937 gen(42); const ModelInfo info = { diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index 2e6b293..e832793 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -164,6 +164,7 @@ cc_library( ":io", ":nuq", ":sfp", + "//:allocator", "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index ef30033..7fd097c 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -197,7 +197,6 @@ struct CompressTraits { size_t num, CompressPerThread& tls, const PackedSpan& packed, const size_t packed_ofs) { - const hn::RebindToUnsigned du; const hn::Repartition dbf; const size_t NF = hn::Lanes(df); diff --git a/compression/compress.h b/compression/compress.h index adb35a1..7453310 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -20,10 +20,9 @@ #define COMPRESS_STATS 0 #include +#include #include -#include -#include #include #include #include @@ -35,70 +34,23 @@ #include "compression/io.h" #include "compression/shared.h" // IWYU pragma: end_exports -#include "compression/distortion.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" // BF16 -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "util/allocator.h" #if COMPRESS_STATS +#include "compression/distortion.h" #include "hwy/stats.h" #endif namespace gcpp { -// Compressed representation of floating-point elements. The array length may -// differ from the number of elements. Associated operations such as Dot are -// implemented in SIMD code and are thus non-member functions. -template -class CompressedArray { - public: - using value_type = Packed; - - // Note that whenever you access data(), you have to consider a scale() that - // may be different from 1.0f. - Packed* data() { return data_.data(); } - const Packed* data() const { return data_.data(); } - // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so - // calling it means "I am sure the scale is 1 and therefore ignore the scale". - // A scale of 0 indicates that the scale has likely never been set, so is - // "implicitly 1". 
- const Packed* data_scale1() const { - HWY_ASSERT(scale() == 1.f || scale() == 0.f); - return data_.data(); - } - - // Decoded elements should be multiplied by this to restore their original - // range. This is required because SfpStream can only encode a limited range - // of magnitudes. - float scale() const { return scale_[0]; } - void set_scale(float scale) { scale_[0] = scale; } - - constexpr size_t NumElements() const { return kCapacity; } - - // Returns total number of packed elements for `BlobReader::Enqueue` and - // `Compress`. This differs from `NumElements` for `Packed=NuqStream`. - PackedSpan GetSpan() { return MakeSpan(data(), data_.size()); } - PackedSpan GetSpan() const { - return MakeSpan(data(), data_.size()); - } - - private: - std::array(kCapacity)> data_; - // Blobs are at least kBlobAlign bytes anyway. - float scale_[kBlobAlign / sizeof(float)]; -}; - -// Yet another array class. This one is intended to be compatible with -// CompressedArray, but have both run-time sizing and compile-time constant -// size. -// It also provides easy conversion from/to a table of contents for a BlobStore -// file, and a templated (compile-time) accessor for a 2-d array of fixed inner -// dimension and type. -// The base class is intended for accessing the metadata, without needing to -// know any of the template arguments. -// It holds only a borrowed pointer to the data, but all metadata. +// Base class for rank-1 or 2 tensors (vector or matrix). +// Supports both dynamic and compile-time sizing. +// Holds metadata and a non-owning pointer to the data, owned by the derived +// MatStorageT class. +// This class also provides easy conversion from/to a table of contents for a +// BlobStore file, and a templated (compile-time) accessor for a 2-d array of +// fixed inner dimension and type. // It is designed to be put in a vector, and has default copy and operator=, so // it is easy to read/write a blob_store file. -// The derived class or an external class owns the data. class MatPtr { public: // Full constructor for dynamic sizing. @@ -111,12 +63,12 @@ class MatPtr { rows_(rows), cols_(cols), ptr_(nullptr) {} - // Default constructor doesn't set anything. + // Default is to leave all fields default-initialized. MatPtr() = default; virtual ~MatPtr(); // Number of hwy::uint128_t in a TOC entry. - // Note that the old-style BlobStore files Only have a list of keys and size. + // Note that the old-style BlobStore files only have a list of keys and size. // The new-style BlobStore files have an entry called "toc" that contains a // vector of 4-tuples of // (name, type, (num_elements, element_size), (rows, cols)). @@ -144,6 +96,7 @@ class MatPtr { } // Compatibility interface for CompressedArray. + // TODO: remove. template T* data() { return HWY_RCAST_ALIGNED(T*, ptr_); @@ -177,7 +130,6 @@ class MatPtr { // Returns the number of bytes in the array. size_t SizeBytes() const { return num_elements_ * element_size_; } - size_t CompressedSize() const { return SizeBytes(); } // Returns the number of rows in the 2-d array (outer dimension). size_t Rows() const { return rows_; } @@ -211,8 +163,8 @@ class MatPtr { } // Calls func on the upcasted type. Since MatPtr by design is not templated, - // here we provide a way to get to the derived type, provided that the type - // matches one of a known short-list. + // here we provide a way to get to the derived type, provided that `Type()` + // is one of the strings returned by `TypeName()`. template decltype(auto) CallUpcasted(FuncT& func, TArgs&&... 
args); @@ -243,8 +195,6 @@ class MatPtr { template class MatPtrT : public MatPtr { public: - using value_type = MatT; - // Full constructor for dynamic sizing. MatPtrT(const std::string& name, size_t rows, size_t cols) : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} @@ -276,20 +226,13 @@ class MatPtrT : public MatPtr { } return name; } + // Sets the number of elements in the array. For use when the number of // elements is != rows * cols ONLY. void SetNumElements(size_t num_elements) { num_elements_ = CompressedArrayElements(num_elements); } - // Fast 2-d accessor for a 2-d array of fixed inner dimension and type. - template - const T& AtT(size_t row, size_t col) const { - size_t index = row * kInner + col; - HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; - } - // 2-d Accessor for a specific type but with a dynamic inner dimension. template const T& At(size_t row, size_t col) const { @@ -299,17 +242,15 @@ class MatPtrT : public MatPtr { } // 1-d Accessor for a specific type. - template - const T& At(size_t index) const { + // TODO: replace this with a Foreach(), or at least a ForEachRow(). + const MatT& At(size_t index) const { HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; - } - template - T& At(size_t index) { - return HWY_RCAST_ALIGNED(T*, ptr_)[index]; + return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index]; } + MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; } // Compatibility interface for CompressedArray. + // TODO: remove template T* data() { return HWY_RCAST_ALIGNED(T*, ptr_); @@ -353,15 +294,14 @@ class MatStorageT : public MatPtrT { public: // Full constructor for dynamic sizing. MatStorageT(const std::string& name, size_t rows, size_t cols) - : MatPtrT(name, rows, cols), - data_(hwy::AllocateAligned( - hwy::DivCeil(this->SizeBytes(), sizeof(MatT)))) { - this->ptr_ = data_.get(); + : MatPtrT(name, rows, cols) { + Allocate(); } // Can copy the metadata, from a MatPtr, and allocate later. MatStorageT(const MatPtr& other) : MatPtrT(other) {} + ~MatStorageT() = default; - // No copying of MatStorageT as it contains big data. + // Move-only because this contains a unique_ptr. MatStorageT(const MatStorageT& other) = delete; MatStorageT& operator=(const MatStorageT& other) = delete; MatStorageT(MatStorageT&& other) = default; @@ -377,7 +317,7 @@ class MatStorageT : public MatPtrT { } else { this->num_elements_ = num_elements; } - data_ = hwy::AllocateAligned(num_elements); + data_ = Allocator::Alloc(num_elements); this->ptr_ = data_.get(); } @@ -388,8 +328,6 @@ class MatStorageT : public MatPtrT { } private: - // Aligned data array. - // std::unique_ptr data_; hwy::AlignedFreeUniquePtr data_; }; @@ -507,7 +445,7 @@ class CompressStats { }; #else struct CompressStats { - void Notify(const DistortionStats&) {} + void Notify(...) {} void NotifyIn(int) {} void Assimilate(const CompressStats&) {} void PrintAll() {} @@ -526,18 +464,17 @@ struct CompressWorkingSet { // Functor called for each tensor, which loads them and their scaling factors // from BlobStore. 
-class CacheLoader { +class ReadFromBlobStore { public: - explicit CacheLoader(const Path& blob_filename) { + explicit ReadFromBlobStore(const Path& blob_filename) { err_ = reader_.Open(blob_filename); - if (err_ != 0) { - fprintf(stderr, - "Cached compressed weights does not exist yet (code %d), " - "loading from file: %s.\n", - err_, blob_filename.path.c_str()); + if (HWY_UNLIKELY(err_ != 0)) { + fprintf(stderr, "Error %d opening BlobStore %s.\n", err_, + blob_filename.path.c_str()); + return; // avoid overwriting err_ to ensure ReadAll will fail. } err_ = file_toc_.LoadToc(reader_); - if (err_ != 0) { + if (HWY_UNLIKELY(err_ != 0)) { fprintf(stderr, "Found a TOC, but failed to load it (code %d)\n", err_); } } diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index abae040..1af2f1a 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -36,10 +36,8 @@ #include "util/args.h" #include "util/threading.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/contrib/thread_pool/topology.h" #include "hwy/highway.h" -#include "hwy/per_target.h" +#include "hwy/per_target.h" // VectorBytes #include "hwy/timer.h" namespace gcpp { @@ -57,7 +55,7 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, const AppArgs& app) - : pools_(app.max_clusters, app.max_threads, app.pin) { + : pools_(CreatePools(app)) { InferenceArgs mutable_inference = inference; AbortIfInvalidArgs(mutable_inference); LoaderArgs mutable_loader = loader; @@ -217,7 +215,7 @@ void LogSpeedStats(double time_start, size_t total_tokens) { } void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - PerClusterPools& pools) { + NestedPools& pools) { loader.Print(app.verbosity); inference.Print(app.verbosity); app.Print(app.verbosity); @@ -228,21 +226,15 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, char cpu100[100] = "unknown"; (void)hwy::platform::GetCpuString(cpu100); - // TODO: call TopologyString() once we have NestedPools. - const std::vector& clusters = - pools.CoresPerCluster(); - const size_t per_cluster = - clusters.empty() ? 0 : pools.CoresPerCluster().front().Count(); fprintf(stderr, "Date & Time : %s" // dt includes \n "CPU : %s\n" - "CPU topology : %zux%zu, using %zux%zu\n" + "CPU topology : %s\n" "Instruction set : %s (%zu bits)\n" "Compiled config : %s\n" "Weight Type : %s\n" "EmbedderInput Type : %s\n", - dt, cpu100, pools.CoresPerCluster().size(), per_cluster, - pools.Outer().NumWorkers(), pools.Inner(0).NumWorkers(), + dt, cpu100, pools.TopologyString(), hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8, CompiledConfig(), StringFromType(loader.Info().weight), TypeName()); diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 397fc20..7e7f1bf 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -106,7 +106,7 @@ class GemmaEnv { private: // Thread pool for running inference. - PerClusterPools pools_; + NestedPools pools_; // Random number generator. std::mt19937 gen_; // The model to run inference on. 
@@ -121,7 +121,7 @@ class GemmaEnv { void LogSpeedStats(double time_start, size_t total_tokens); void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - PerClusterPools& pools); + NestedPools& pools); void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); } // namespace gcpp diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 2ed9b64..7b2e90f 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -54,7 +54,7 @@ int main(int argc, char** argv) { } // Instantiate model and KV Cache - gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); + gcpp::NestedPools pools = gcpp::CreatePools(app); gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.GetModelConfig(), diff --git a/gemma/activations.h b/gemma/activations.h index 3983924..6b39854 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -94,7 +94,7 @@ struct Activations { return inv_timescale; } - void Allocate(size_t batch_size, PerClusterPools& pools) { + void Allocate(size_t batch_size, NestedPools& pools) { post_qk = layer_config.post_qk; const size_t model_dim = weights_config.model_dim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index b34916e..028949c 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1131,7 +1131,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(), - activations.x.Batch(i), activations.env.Pools().Outer()); + activations.x.Batch(i), activations.env.Pool()); } // Add position embeddings. AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(), @@ -1416,7 +1416,7 @@ template void GenerateSingleT(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, PerClusterPools& pools, + KVCache& kv_cache, NestedPools& pools, TimingInfo& timing_info) { constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; @@ -1440,7 +1440,7 @@ void GenerateBatchT(const ModelWeightsStorage& model, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, PerClusterPools& pools, + const KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(queries_pos.size() == num_queries); @@ -1477,7 +1477,7 @@ template void GenerateImageTokensT(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens, - PerClusterPools& pools) { + NestedPools& pools) { if (model.Config().vit_layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); } @@ -1500,7 +1500,7 @@ void GenerateImageTokensT(const ModelWeightsStorage& model, void GenerateSingle( // NOLINT(misc-definitions-in-headers) GEMMA_TYPE, const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, - size_t prefix_end, KVCache& kv_cache, PerClusterPools& pools, + size_t prefix_end, KVCache& kv_cache, NestedPools& pools, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) (model, runtime_config, prompt, pos, prefix_end, kv_cache, pools, @@ -1512,7 +1512,7 @@ void GenerateBatch( // 
NOLINT(misc-definitions-in-headers) const RuntimeConfig& runtime_config, const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, - PerClusterPools& pools, TimingInfo& timing_info) { + NestedPools& pools, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) (model, runtime_config, queries_prompt, queries_pos, queries_prefix_end, kv_caches, pools, timing_info); @@ -1521,7 +1521,7 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers) void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) GEMMA_TYPE, const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, PerClusterPools& pools) { + ImageTokens& image_tokens, NestedPools& pools) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT) (model, runtime_config, image, image_tokens, pools); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 0c5f089..aa01b7f 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -27,6 +27,7 @@ #include #include "compression/io.h" // Path +#include "compression/shared.h" #include "gemma/common.h" #include "gemma/weights.h" #include "ops/ops-inl.h" @@ -39,16 +40,16 @@ namespace gcpp { Gemma::Gemma(const Path& tokenizer_path, const Path& weights, - const ModelInfo& info, PerClusterPools& pools) + const ModelInfo& info, NestedPools& pools) : pools_(pools), tokenizer_(tokenizer_path), info_(info) { - model_.Load(weights, info.model, info.weight, pools_.Inner(0)); + model_.Load(weights, info.model, info.weight, pools_.Pool()); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, - PerClusterPools& pools) + NestedPools& pools) : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) { HWY_ASSERT(info.weight == Type::kF32); - model_.Allocate(info.model, info.weight, pools_.Inner(0)); + model_.Allocate(info.model, info.weight, pools_.Pool()); } Gemma::~Gemma() { @@ -63,17 +64,16 @@ Gemma::~Gemma() { const RuntimeConfig& runtime_config, \ const PromptTokens& prompt, size_t pos, \ size_t prefix_end, KVCache& kv_cache, \ - PerClusterPools& pools, TimingInfo& timing_info); \ + NestedPools& pools, TimingInfo& timing_info); \ extern void GenerateBatch( \ TWEIGHT, const ModelWeightsStorage& model, \ const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \ - const KVCaches& kv_caches, PerClusterPools& pools, \ - TimingInfo& timing_info); \ + const KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info); \ extern void GenerateImageTokens( \ TWEIGHT, const ModelWeightsStorage& model, \ const RuntimeConfig& runtime_config, const Image& image, \ - ImageTokens& image_tokens, PerClusterPools& pools); + ImageTokens& image_tokens, NestedPools& pools); GEMMA_DECLARE(float) GEMMA_DECLARE(BF16) GEMMA_DECLARE(NuqStream) @@ -85,7 +85,7 @@ struct GenerateSingleT { void operator()(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, PerClusterPools& pools, + KVCache& kv_cache, NestedPools& pools, TimingInfo& timing_info) const { GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end, kv_cache, pools, timing_info); @@ -99,7 +99,7 @@ struct GenerateBatchT { const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, PerClusterPools& pools, + const 
KVCaches& kv_caches, NestedPools& pools, TimingInfo& timing_info) const { GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos, queries_prefix_end, kv_caches, pools, timing_info); @@ -110,7 +110,7 @@ template struct GenerateImageTokensT { void operator()(const ModelWeightsStorage& model, const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, PerClusterPools& pools) const { + ImageTokens& image_tokens, NestedPools& pools) const { GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens, pools); } diff --git a/gemma/gemma.h b/gemma/gemma.h index ce7b835..cee99f3 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -183,11 +183,10 @@ struct TimingInfo { class Gemma { public: Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, - PerClusterPools& pools); + NestedPools& pools); // Allocates weights, caller is responsible for filling them. - Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, - PerClusterPools& pools); + Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, NestedPools& pools); ~Gemma(); const ModelConfig& GetModelConfig() const { return model_.Config(); } @@ -229,7 +228,7 @@ class Gemma { const Image& image, ImageTokens& image_tokens); private: - PerClusterPools& pools_; + NestedPools& pools_; GemmaTokenizer tokenizer_; // Type-erased so that this can be defined in the header. diff --git a/gemma/run.cc b/gemma/run.cc index 6f53fc7..ecbdcce 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -188,9 +188,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const InferenceArgs& args, void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { PROFILER_ZONE("Run.misc"); + // TODO: remove once MatMul is updated. + app.max_packages = 1; // Note that num_threads is an upper bound; we also limit to the number of // detected and enabled cores. - PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); + NestedPools pools = CreatePools(app); + Allocator::Init(pools.Topology()); Gemma model = CreateGemma(loader, pools); KVCache kv_cache = diff --git a/gemma/weights.cc b/gemma/weights.cc index de54ef3..e0f4d8c 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -21,6 +21,7 @@ #include #include +#include "compression/blob_store.h" #include "compression/compress.h" #include "compression/io.h" // Path #include "gemma/common.h" @@ -36,7 +37,7 @@ namespace gcpp { template struct TensorLoader { void operator()(ModelWeightsPtrs& weights, ForEachType fet, - CacheLoader& loader) { + ReadFromBlobStore& loader) { weights.ForEachTensor( {&weights}, fet, [&loader](const char* name, hwy::Span tensors) { @@ -52,7 +53,7 @@ BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type, HWY_ABORT("The model weights file '%s' does not exist.", weights.path.c_str()); } - CacheLoader loader(weights); + ReadFromBlobStore loader(weights); ForEachType fet = loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; if (fet == ForEachType::kLoadWithToc) { diff --git a/ops/dot-inl.h b/ops/dot-inl.h index 012a956..cbb34f6 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -15,8 +15,6 @@ #include -#include - #include "compression/compress.h" #include "hwy/base.h" @@ -376,15 +374,6 @@ HWY_INLINE float Dot(const WT* HWY_RESTRICT w, const VT* vec, size_t num) { return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num); } -// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. 
-template -HWY_INLINE float Dot(const CompressedArray& w, size_t w_ofs, - const VT* vec_aligned, size_t num) { - const hn::ScalableTag d; - return w.scale() * - Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num); -} - // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. template HWY_INLINE float Dot(const MatPtrT& w, size_t w_ofs, diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index b2678a3..20eb916 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -470,7 +470,7 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat& A, env.Pool().Run( 0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR { // TODO: when using PerClusterPool, compute lp from outer and inner. - float* HWY_RESTRICT buf = env.Buf(thread); + float* HWY_RESTRICT buf = env.Buf().Batch(thread); const size_t tx = idx_tile % tilesX; const size_t ty = idx_tile / tilesX; const size_t row_ac = ty * kRegRows; diff --git a/ops/matmul.h b/ops/matmul.h index 34851f5..c643062 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -18,12 +18,15 @@ #include -#include "util/allocator.h" // RowVectorBatch +// IWYU pragma: begin_exports #include "util/threading.h" -#include "hwy/aligned_allocator.h" // IWYU pragma: export +#include "hwy/aligned_allocator.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export -#include "hwy/per_target.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +// IWYU pragma: end_exports + +#include "util/allocator.h" // RowVectorBatch +#include "hwy/per_target.h" // VectorBytes namespace gcpp { @@ -80,19 +83,18 @@ Mat ConstMat(const T* HWY_RESTRICT ptr, size_t cols) { class MatMulEnv { public: MatMulEnv() : pools_(nullptr) {} - explicit MatMulEnv(PerClusterPools& pools) : pools_(&pools) { - const size_t num_lp = pools.NumLP(); - const size_t NF = hwy::VectorBytes() / sizeof(float); - buf_ = RowVectorBatch(num_lp, 16 * NF); + explicit MatMulEnv(NestedPools& pools) : pools_(&pools) { + const size_t N = hwy::VectorBytes() / sizeof(float); + buf_ = RowVectorBatch(pools.MaxWorkers(), 16 * N); } - float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); } - PerClusterPools& Pools() const { return *pools_; } - hwy::ThreadPool& Pool() const { return pools_->Inner(0); } + RowVectorBatch& Buf() { return buf_; } + NestedPools& Pools() const { return *pools_; } + hwy::ThreadPool& Pool() const { return pools_->Pool(); } private: RowVectorBatch buf_; - PerClusterPools* pools_; + NestedPools* pools_; }; } // namespace gcpp diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index b6445b3..3b6c7bd 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -14,9 +14,14 @@ // limitations under the License. #ifndef HWY_DISABLED_TARGETS -// Exclude HWY_SCALAR due to 2x bf16 -> f32. +// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require +// double-precision support. 
+#if HWY_ARCH_ARM_V7 +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON) +#else #define HWY_DISABLED_TARGETS HWY_SCALAR #endif +#endif #include "ops/matmul.h" @@ -26,8 +31,8 @@ #include #include "compression/compress.h" +#include "util/allocator.h" #include "util/threading.h" -#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" @@ -165,7 +170,7 @@ template HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, const MatTA* HWY_RESTRICT a, const MatTB* HWY_RESTRICT b_trans, const float scale, - const float* HWY_RESTRICT add, MatMulEnv& env, + const float* HWY_RESTRICT add_row, MatMulEnv& env, float* HWY_RESTRICT out) { // MatTA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not @@ -176,23 +181,21 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, const PackedSpan b_span = MakeSpan(b_trans, cols_a_rows_b * cols_bc); - env.Pools().Outer().Run( - 0, rows_ac, [&](const uint64_t i, size_t o_thread) HWY_ATTR { - hwy::ThreadPool& inner = env.Pools().Inner(o_thread); - if (add != nullptr) { - inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) { - out[i * cols_bc + j] = - scale * Dot(df, b_span, j * cols_a_rows_b, - a + i * cols_a_rows_b, cols_a_rows_b) + - add[j]; - }); - } else { - inner.Run(0, cols_bc, [&](const uint64_t j, size_t i_thread) { - out[i * cols_bc + j] = - scale * Dot(df, b_span, j * cols_a_rows_b, - a + i * cols_a_rows_b, cols_a_rows_b); - }); - } + StaticPartitionRowsAndCols( + env.Pools(), rows_ac, cols_bc, sizeof(MatTB), + [&](size_t /*node*/, hwy::ThreadPool& pool, + const size_t /*worker_offset*/, const size_t row_begin, + const size_t row_end, const size_t col_begin, const size_t col_end) { + pool.Run(row_begin, row_end, + [&](const uint64_t row, size_t /*thread*/) { + for (size_t col = col_begin; col < col_end; ++col) { + const float add = add_row ? add_row[col] : 0.0f; + out[row * cols_bc + col] = + scale * Dot(df, b_span, col * cols_a_rows_b, + a + row * cols_a_rows_b, cols_a_rows_b) + + add; + } + }); }); } @@ -261,9 +264,10 @@ void TestAllMatMul() { return; } - PerClusterPools pools(/*max_clusters=*/1, /*max_threads=*/4, /*pin=*/1); - MatMulEnv env(pools); + NestedPools pools(4, /*pin=*/1); pools.StartSpinning(); + Allocator::Init(pools.Topology()); + MatMulEnv env(pools); using F32 = float; using SFP = SfpStream; diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h index d9fdeeb..95f7f9e 100644 --- a/ops/matvec-inl.h +++ b/ops/matvec-inl.h @@ -45,6 +45,15 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used. +template +HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned, + size_t num) { + const hn::ScalableTag d; + return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs, + vec_aligned, num); +} + // Simple version without tiling nor threading, but two offsets/outputs and // always with addition. template diff --git a/util/allocator.cc b/util/allocator.cc new file mode 100644 index 0000000..a7d2352 --- /dev/null +++ b/util/allocator.cc @@ -0,0 +1,183 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "util/allocator.h" + +#include + +#include + +#include "util/basics.h" // MaybeCheckInitialized + +#if GEMMA_NUMA +#if HWY_OS_WIN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#elif HWY_OS_LINUX +#include + +#include +#endif // HWY_OS_* +#endif // GEMMA_NUMA + +namespace gcpp { + +/*static*/ size_t Allocator::bytes_per_page_; +/*static*/ bool Allocator::use_numa_; +/*static*/ size_t Allocator::alignment_; + +/*static*/ size_t Allocator::DetectPageSize() { +#if HWY_OS_WIN + SYSTEM_INFO sys_info; + GetSystemInfo(&sys_info); + return sys_info.dwPageSize; +#elif HWY_OS_LINUX + return sysconf(_SC_PAGESIZE); +#else + return 0; +#endif +} + +#if GEMMA_NUMA && HWY_OS_LINUX + +using Ret = long; // NOLINT(runtime/int) +using UL = unsigned long; // NOLINT(runtime/int) +static constexpr size_t ULBits = sizeof(UL) * 8; + +// Calling via syscall avoids a dependency on libnuma. +struct SyscallWrappers { + static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes, + unsigned flags) { + MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL)); + return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags); + }; + + static Ret move_pages(int pid, UL count, void** pages, const int* nodes, + int* status, int flags) { + MaybeCheckInitialized(pages, count * sizeof(void*)); + MaybeCheckInitialized(nodes, count * sizeof(int)); + MaybeCheckInitialized(status, count * sizeof(int)); + return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags); + } +}; + +size_t CountBusyPages(size_t num_pages, size_t node, void** pages, + const int* status) { + // Return value 0 does not actually guarantee all pages were moved. + size_t num_busy = 0; + for (size_t i = 0; i < num_pages; ++i) { + if (status[i] == -EBUSY) { + ++num_busy; + // Touch + hwy::ZeroBytes(pages[i], 8); + } else if (status[i] != static_cast(node)) { + fprintf(stderr, "Error %d moving pages[%zu]=%p to node %zu (errno %d)\n", + status[i], i, pages[i], node, errno); + } + } + return num_busy; +} + +// Attempts to move(!) memory to the given NUMA node, typically obtained from +// `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. Using `mbind` +// directly is easier than calling libnuma's `numa_move_pages`, which requires +// an array of pages. Note that `numa_tonode_memory` is insufficient because +// it does not specify the `MPOL_MF_MOVE` flag, so it only sets the policy, +// which means it would have to be called before pages are faulted in, but +// `aligned_allocator.h` modifies the first bytes for its bookkeeping. +// May overwrite some of the memory with zeros. +static void BindMemory(void* ptr, size_t bytes, size_t node) { + constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" + // Avoid mbind because it does not report why it failed, which is most likely + // because pages are busy, in which case we want to know which. +#if 0 + // nodemask with only the given node set. 
+ UL nodes[hwy::DivCeil(kMaxNodes, ULBits)] = {}; + nodes[node / ULBits] = 1ULL << (node % ULBits); + + const int mode = 2; // MPOL_BIND + const unsigned flags = 3; // MPOL_MF_MOVE | MPOL_MF_STRICT + const int ret = + SyscallWrappers::mbind(ptr, bytes, mode, nodes, kMaxNodes, flags); + if (ret != 0) { + fprintf(stderr, "Failed to bind %p %zu to node %zu (errno %d)\n", ptr, + bytes, node, errno); + } +#elif 1 + const unsigned flags = 2; // MPOL_MF_MOVE + const size_t bytes_per_page = static_cast(sysconf(_SC_PAGESIZE)); + HWY_ASSERT(bytes % bytes_per_page == 0); + const size_t num_pages = bytes / bytes_per_page; + std::vector pages; + pages.reserve(num_pages); + for (size_t i = 0; i < num_pages; ++i) { + pages.push_back(static_cast(ptr) + i * bytes_per_page); + } + std::vector nodes(num_pages, node); + std::vector status(num_pages, static_cast(kMaxNodes)); + Ret ret = SyscallWrappers::move_pages( + /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); + size_t num_busy = + CountBusyPages(num_pages, node, pages.data(), status.data()); + if (num_busy != 0) { + // Try again + ret = SyscallWrappers::move_pages( + /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); + const size_t num_busy_before = num_busy; + num_busy = CountBusyPages(num_pages, node, pages.data(), status.data()); + fprintf( + stderr, + "second try still %zu busy, was %zu. 2nd ret %d status %d %d %d %d\n", + num_busy, num_busy_before, static_cast(ret), status[0], status[1], + status[2], status[3]); + } + + if (ret < 0) { + fprintf(stderr, + "Failed to bind %p %zu to node %zu (errno %d) status %d %d\n", ptr, + bytes, node, errno, status[0], status[1]); + } +#endif +} + +#else +// TODO: support other OSes. +static void BindMemory(void*, size_t, size_t) {} +#endif // GEMMA_NUMA && HWY_OS_LINUX + +void BindTensor(NestedPools& nested, size_t rows, size_t cols, + size_t bytes_per_col, void* ptr) { + if (!Allocator::UseNUMA()) return; + uint8_t* p8 = static_cast(ptr); + const size_t bytes_per_row = cols * bytes_per_col; + StaticPartitionRowsAndCols( + nested, rows, cols, bytes_per_col, + [&](size_t node, hwy::ThreadPool&, const size_t /*worker_offset*/, + const size_t row_begin, const size_t row_end, const size_t col_begin, + const size_t col_end) { + for (size_t row = row_begin; row < row_end; ++row) { + uint8_t* slice = p8 + row * bytes_per_row + col_begin * bytes_per_col; + const size_t slice_size = (col_end - col_begin) * bytes_per_col; + BindMemory(slice, slice_size, node); + } + }); +} + +} // namespace gcpp diff --git a/util/allocator.h b/util/allocator.h index 821268c..386b23f 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -19,8 +19,29 @@ #include #include -#include "hwy/aligned_allocator.h" // IWYU pragma: export +#include // std::aligned_alloc + +// IWYU pragma: begin_exports +#include "util/threading.h" +#include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +// IWYU pragma: end_exports + +#ifndef GEMMA_NUMA +// The check below requires two #if, hence start with 0 and redefine to 1. +#define GEMMA_NUMA 0 + +// To avoid a dependency on libnuma, use syscalls directly. We require six +// arguments, which has been supported by glibc since around 2010. 
+#if defined(__GLIBC__) && defined(__GLIBC_PREREQ) +#if HWY_OS_LINUX && __GLIBC_PREREQ(2, 11) +#undef GEMMA_NUMA +#define GEMMA_NUMA 1 +#endif +#endif + +#endif // GEMMA_NUMA namespace gcpp { @@ -74,6 +95,136 @@ class RowVectorBatch { size_t len_; // columns in the matrix = vector length }; +// Stateful in order to know whether to bind to NUMA nodes. `Monostate` for +// convenience - avoids passing around a reference. +class Allocator { + public: + static void Init(const BoundedTopology& topology) { + bytes_per_page_ = DetectPageSize(); + HWY_ASSERT(bytes_per_page_ <= (4 << 20)); + + // NUMA only makes sense if: + // - the page size is known and 'reasonably small', preferably less than + // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. + // - we successfully detected topology and there are multiple nodes; + // - there are multiple packages, because we shard by package_idx. + use_numa_ = (bytes_per_page_ != 0 && bytes_per_page_ <= 16 * 1024) && + topology.NumNodes() > 1 && topology.NumPackages() > 1; + // TODO: remove once tensors are page-aligned. + use_numa_ = false; + fprintf(stderr, "Warning: disabling use_numa_\n"); + + alignment_ = use_numa_ ? bytes_per_page_ : HWY_ALIGNMENT; + } + + static bool UseNUMA() { return use_numa_; } + + // BindTensor requires row pointers and lengths be a multiple of this. + static size_t Alignment() { return alignment_; } + + template + static hwy::AlignedFreeUniquePtr Alloc(size_t num) { + // For non-NUMA, use the Highway allocator because it defends against 2k + // aliasing. + if (!use_numa_) return hwy::AllocateAligned(num); + + constexpr size_t kSize = sizeof(T); + // Ensure the `bytes = num * kSize` computation did not overflow. + constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; + constexpr size_t kBits = hwy::detail::ShiftCount(kSize); + static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); + const size_t bytes = kIsPow2 ? num << kBits : num * kSize; + const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; + if (check != num) { + return hwy::AlignedFreeUniquePtr(); // overflowed + } + + // AlignedFreeUniquePtr has a deleter that can call an arbitrary `free`, but + // with an extra opaque pointer, which we discard via this adapter. + const auto call_free = [](void* ptr, void*) { std::free(ptr); }; +#if !defined(__ANDROID_API__) || __ANDROID_API__ >= 28 + T* p = static_cast(std::aligned_alloc(Alignment(), bytes)); +#else + void* mem = nullptr; + int err = posix_memalign(&mem, Alignment(), bytes); + HWY_ASSERT(err == 0); + T* p = static_cast(mem); +#endif + return hwy::AlignedFreeUniquePtr( + p, hwy::AlignedFreer(call_free, nullptr)); + } + + private: + static size_t DetectPageSize(); + + // Required for BindMemory. Usually 4K, but can differ on Arm. + static size_t bytes_per_page_; + static bool use_numa_; + static size_t alignment_; +}; + +// Used in MatMul and allocator.h. Defined here because it depends on +// Allocator::Alignment(). +template +void StaticPartitionRowsAndCols(NestedPools& nested, size_t rows, size_t cols, + size_t bytes_per_element, const Func& func) { + // Both rows and cols must be a multiple of the alignment to avoid + // touching remote pages. + const size_t multiple = Allocator::Alignment() / bytes_per_element; + + // Static partitioning of columns across packages. We assume that column + // sharding is more expensive, hence we distribute columns across packages, + // of which there are usually only one or two. 
For MatMul, the final result is + // the sum of each package's partial dot products. + hwy::ThreadPool& all_packages = nested.AllPackages(); + const size_t num_packages = all_packages.NumWorkers(); + const size_t cols_per_package = + hwy::RoundUpTo(hwy::DivCeil(cols, num_packages), multiple); + const size_t col_tasks = hwy::DivCeil(cols, cols_per_package); + HWY_ASSERT(col_tasks <= num_packages); + all_packages.Run( + 0, col_tasks, [&](uint64_t package_idx, size_t package_thread) { + HWY_ASSERT(package_idx == package_thread); // one task per worker + const size_t col_begin = package_idx * cols_per_package; + const size_t col_end = HWY_MIN(col_begin + cols_per_package, cols); + + // Static partitioning of rows across the package's clusters. We assume + // that row sharding is cheaper. In MatMul, results can indeed be + // computed independently for each row of B. + hwy::ThreadPool& all_clusters = nested.AllClusters(package_idx); + const size_t num_clusters = all_clusters.NumWorkers(); + const size_t rows_per_cluster = + hwy::RoundUpTo(hwy::DivCeil(rows, num_clusters), multiple); + const size_t row_tasks = hwy::DivCeil(rows, rows_per_cluster); + HWY_ASSERT(row_tasks <= num_clusters); + all_clusters.Run( + 0, row_tasks, [&](uint64_t cluster_idx, size_t cluster_thread) { + HWY_ASSERT(cluster_idx == cluster_thread); // one task per worker + + // For binding to NUMA node. + const size_t node = nested.Node(package_idx, cluster_idx); + // Older CPUs that predate chiplets typically have only one + // cluster, so callers should also parallelize using this + // per-cluster pool. + hwy::ThreadPool& cluster = + nested.Cluster(package_idx, cluster_idx); + // This plus the worker from `cluster->Run` is the TLS index. + const size_t worker_offset = + nested.WorkerOffset(package_idx, cluster_idx); + + const size_t row_begin = cluster_idx * rows_per_cluster; + const size_t row_end = + HWY_MIN(row_begin + rows_per_cluster, rows); + + func(node, cluster, worker_offset, row_begin, row_end, col_begin, + col_end); + }); + }); +} + +void BindTensor(NestedPools& nested, size_t rows, size_t cols, + size_t bytes_per_col, void* ptr); + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_UTIL_ALLOCATOR_H_ diff --git a/util/app.h b/util/app.h index c0a2d91..bf7dc27 100644 --- a/util/app.h +++ b/util/app.h @@ -194,13 +194,12 @@ struct LoaderArgs : public ArgsBase { ModelInfo info_; }; -static inline Gemma CreateGemma(const LoaderArgs& loader, - PerClusterPools& pools) { +static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) { return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools); } static inline std::unique_ptr AllocateGemma(const LoaderArgs& loader, - PerClusterPools& pools) { + NestedPools& pools) { return std::make_unique(loader.tokenizer, loader.weights, loader.Info(), pools); } diff --git a/util/threading.h b/util/threading.h index bf26ca0..bc22579 100644 --- a/util/threading.h +++ b/util/threading.h @@ -13,8 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Shared between various frontends. - #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_THREADING_H_ @@ -32,252 +30,45 @@ namespace gcpp { -// DEPRECATED, will be replaced by NestedPools once MatMul is updated. -// Owns 'inner' thread pools, one per 'cluster' (CCX or socket), plus an -// 'outer' thread pool with one worker per cluster. 
-// -// Useful for hierarchical parallelism, which makes sense when there are few -// but large tasks which should be parallelized by workers sharing a cache. -// This also implies lower latency for barrier synchronization of those workers. -class PerClusterPools { - using LPS = hwy::LogicalProcessorSet; - - static inline std::vector CoresInLPS(const LPS& cluster) { - std::vector cores; - cores.reserve(cluster.Count()); - cluster.Foreach([&cores](size_t idx) { cores.push_back(idx); }); - return cores; - } - - using CoreBitSets = std::vector; - - // Returns empty if detection failed. - CoreBitSets DetectCoresPerCluster() { - CoreBitSets clusters; - if (!have_threading_support_) return clusters; - - // Which processors are not disabled via OS, taskset, or numactl. - LPS enabled; - // If we don't know, better to abort rather than risk oversubscribing. - if (!GetThreadAffinity(enabled)) return clusters; - - hwy::Topology topology; - if (topology.packages.empty()) return clusters; - - // Merge all clusters into one set, as a stopgap to emulate gemma-inl's - // prior single pool. - // TODO: remove once MatMul supports hierarchical parallelism. - LPS all; - - // For each cluster, add its enabled *cores*. - for (const hwy::Topology::Package& package : topology.packages) { - for (const hwy::Topology::Cluster& cluster : package.clusters) { - cluster.lps.Foreach([&](size_t lp) { - if (enabled.Get(lp) && topology.lps[lp].smt == 0) { - all.Set(lp); - } - }); - } - - /* code to reinstate: - for (const hwy::Topology::Cluster& cluster : package.clusters) { - // Only use enabled *cores*, and only add if not empty. - cluster.lps.Foreach([&](size_t lp) { - if (enabled.Get(lp) && topology.lps[lp].smt == 0) { - all.Set(lp); - } - }); - if (lps.Any()) clusters.push_back(lps); - } - */ - } - if (all.Any()) clusters.push_back(all); - - // Sort by descending number of enabled cores, so that we preferentially - // use the largest clusters. - std::sort(clusters.begin(), clusters.end(), - [](const LPS& a, const LPS& b) { return a.Count() > b.Count(); }); - - return clusters; - } - - void SetWaitMode(hwy::PoolWaitMode wait_mode) { - outer_pool_.SetWaitMode(wait_mode); - for (auto& inner : inner_pools_) { - inner->SetWaitMode(wait_mode); - } - } - - // `user_max_or_zero` == 0 means no limit, which is the case for the defaults - // of `AppArgs` `max_clusters` and `num_threads`. - static inline size_t CapIfNonZero(size_t num_workers, - size_t user_max_or_zero) { - return (user_max_or_zero == 0) ? num_workers - : HWY_MIN(num_workers, user_max_or_zero); - } - - // Returns the number of threads for `ThreadPool` to create: zero if there is - // no threading support, otherwise the capped number of workers minus the - // caller of `ThreadPool::Run`, which is the outer worker or main thread. - size_t CappedNumThreads(size_t num_workers, size_t user_max_or_zero) const { - if (!have_threading_support_) return 0; - const size_t capped_num_workers = - CapIfNonZero(num_workers, user_max_or_zero); - // Avoid underflow if number of workers is zero. - return capped_num_workers == 0 ? 0 : capped_num_workers - 1; - } - - // Returns the number of workers for the inner pool whose index is `outer`, or - // 0 to indicate no limit if `max_threads` is zero. - size_t MaxInnerWorkers(const size_t max_threads, const size_t outer_workers, - const size_t outer) const { - HWY_DASSERT(outer < outer_workers); - if (max_threads == 0) return 0; // no limit - // Round down so we do not exceed the max. 
- const size_t max_threads_per_outer = max_threads / outer_workers; - // First outer pool gets the remainder. - const size_t remainder = (outer == 0) ? (max_threads % outer_workers) : 0; - return 1 + max_threads_per_outer + remainder; - } - - public: - // Move-only. - PerClusterPools() = delete; - PerClusterPools(const PerClusterPools&) = delete; - PerClusterPools& operator=(const PerClusterPools&) = delete; - PerClusterPools(PerClusterPools&&) = delete; - PerClusterPools& operator=(PerClusterPools&&) = delete; - - // PerClusterPools supports spin waits (see StartSpinning below). To prevent - // drastic slowdowns caused by excessive user-specified thread counts, which - // result in threads not running on their own core, we only allow for - // *upper bounds* on the number of clusters and threads. The actual number of - // clusters and threads are still limited by the detected topology. - // `max_threads` is the upper bound on threads to distribute among clusters, - // not including the one outer thread per cluster. - // - // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically. - PerClusterPools(size_t max_clusters, size_t max_threads, int pin = -1) - : have_threading_support_(hwy::HaveThreadingSupport()), - cores_per_cluster_(DetectCoresPerCluster()), - outer_pool_(CappedNumThreads(cores_per_cluster_.size(), max_clusters)) { - // Topology detection failed - it currently requires Linux. - if (cores_per_cluster_.empty()) { - // Create a single inner pool with up to TotalLogicalProcessors() / 2 - // workers, further limited by `max_threads` if nonzero, and then pin to - // the first N processors, which are typically on the first socket. - const size_t num_threads = - CappedNumThreads(hwy::TotalLogicalProcessors() / 2, max_threads); - if (pin == -1) pin = num_threads > 8; - fprintf(stderr, "CPU topology unknown, using %zu threads, pin %d\n", - num_threads, pin); - inner_pools_.push_back(std::make_unique(num_threads)); - if (num_threads > 1 && pin) { - inner_pools_.back()->Run(0, num_threads, - [](uint64_t /*task*/, size_t thread) { - hwy::PinThreadToLogicalProcessor(thread); - }); - } - return; - } - - for (size_t outer = 0; outer < outer_pool_.NumWorkers(); ++outer) { - const size_t max_inner_workers = - MaxInnerWorkers(max_threads, outer_pool_.NumWorkers(), outer); - const size_t num_threads = CappedNumThreads( - cores_per_cluster_[outer].Count(), max_inner_workers); - inner_pools_.push_back(std::make_unique(num_threads)); - } - - if (pin == -1) { - pin = (outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers()) >= 12; - } - - if (pin) { - // For each inner pool, pin their threads AND the associated outer thread - // (the one calling inner.Run()) to the enabled cores in the cluster. - outer_pool_.Run( - 0, outer_pool_.NumWorkers(), - [this](uint64_t outer, size_t outer_thread) { - HWY_ASSERT(outer == outer_thread); // each outer has one task - hwy::ThreadPool& inner = *inner_pools_[outer]; - - const std::vector cores = - CoresInLPS(cores_per_cluster_[outer]); - // May have been capped by max_threads. - HWY_ASSERT(inner.NumWorkers() <= cores.size()); - - inner.Run(0, inner.NumWorkers(), - [&cores](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each inner has one task - hwy::PinThreadToLogicalProcessor(cores[task]); - }); - }); - } - } - - // Spinning reduces the latency of barrier synchronization, but wastes lots of - // energy for long waits, so only do it during generation. 
This might also be - // unsafe in virtualized environments because we require threads to be running - // on their own core and thus responsive to the barrier synchronization. - void StartSpinning() { SetWaitMode(hwy::PoolWaitMode::kSpin); } - void StopSpinning() { SetWaitMode(hwy::PoolWaitMode::kBlock); } - - // Bitset of cores, one per cluster, or empty if detection failed. Useful for - // displaying the topology. - const CoreBitSets& CoresPerCluster() const { return cores_per_cluster_; } - - hwy::ThreadPool& Outer() { return outer_pool_; } - hwy::ThreadPool& Inner(size_t outer) { - HWY_ASSERT(outer < Outer().NumWorkers()); - return *inner_pools_[outer]; - } - - // Returns number of logical processors, for allocating per-thread buffers. - size_t NumLP() const { - return outer_pool_.NumWorkers() * inner_pools_[0]->NumWorkers(); - } - - private: - bool have_threading_support_; - CoreBitSets cores_per_cluster_; - hwy::ThreadPool outer_pool_; - // hwy::ThreadPool is unfortunately not marked as movable, so we have to use - // unique_ptr. - std::vector> inner_pools_; -}; - // A slice of a 1D integer range such as the indices of packages or clusters. // This allows assigning them to multiple instances of our binary. -struct BoundedSlice { +class BoundedSlice { + public: // Defaults to "use all detected". - BoundedSlice(size_t skip = 0, size_t max = 0) : skip(skip), max(max) {} + BoundedSlice(size_t skip = 0, size_t max = 0) : skip_(skip), max_(max) {} - // How many to skip, or equivalently, index of the first to use. It is an - // error if this is >= `detected`, because that would leave none for this - // instance to use. - size_t skip; - - // Upper bound on the number to use, or zero if no limit. - size_t max; + size_t Begin() const { return skip_; } // STL-style one past the end. size_t End(size_t detected) const { - return (max == 0) ? detected : HWY_MIN(detected, skip + max); + return (max_ == 0) ? detected : HWY_MIN(detected, skip_ + max_); } // Number of elements in the slice. - size_t Num(size_t detected) const { return End(detected) - skip; } + size_t Num(size_t detected) const { return End(detected) - Begin(); } + + bool Contains(size_t detected, size_t idx) const { + return Begin() <= idx && idx < End(detected); + } template - void ForEach(const char* name, size_t detected, const Func& func) { - if (skip >= detected) { - HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip, name, detected); + void Foreach(const char* name, size_t detected, const Func& func) { + if (Begin() >= detected) { + HWY_ABORT("Invalid skip=%zu for %s, detected=%zu", skip_, name, detected); } - for (size_t i = skip; i < End(detected); ++i) { + for (size_t i = Begin(); i < End(detected); ++i) { func(i); } } + + private: + // How many to skip, or equivalently, index of the first to use. It is an + // error if this is >= `detected`, because that would leave none for this + // instance to use. + size_t skip_; + + // Upper bound on the number to use, or zero if no limit. + size_t max_; }; // "LP" is a logical processor, a 0-based index passed to the OS. @@ -288,201 +79,255 @@ using LPS = hwy::LogicalProcessorSet; // NOTE: if topology is unknown or the OS affinity is too restrictive, we fall // back to a single package and cluster. class BoundedTopology { - // Sort packages/clusters by descending size so that users who only use one - // get the largest. 
- template - static void SortByDescendingLPs(std::vector& groups) { - std::sort(groups.begin(), groups.end(), [](const Group& a, const Group& b) { - return a.num_lps > b.num_lps; - }); - } - public: - struct Cluster { - // Simple version when topology is unknown. - explicit Cluster(size_t num_workers) : num_lps(num_workers) { - HWY_ASSERT(num_lps != 0); - } - - Cluster(const std::vector& all_lps, const LPS& enabled, - size_t package_lp, const hwy::Topology::Cluster& cluster, - LPS& package_lps) { - // All first-hyperthread LPs from the cluster that are enabled and not - // already in use as the package representative. - cluster.lps.Foreach([&](size_t lp) { - if (all_lps[lp].smt == 0 && enabled.Get(lp) && lp != package_lp) { - HWY_ASSERT(!lps.Get(lp)); - lps.Set(lp); - HWY_ASSERT(!package_lps.Get(lp)); - package_lps.Set(lp); - } - }); - num_lps = lps.Count(); // = 0 if all disabled. - } - - LPS lps; - size_t num_lps; - // Set by caller to the first of `lps` if there are multiple clusters in a - // package. - size_t cluster_lp = 0; - }; - - struct Package { - // Simple version when topology is unknown. - explicit Package(size_t num_workers) { - package_lp = 0; - num_lps = num_workers; - clusters.push_back(Cluster(num_workers)); - } - - Package(size_t package_idx, const hwy::Topology& topology, - const LPS& enabled, BoundedSlice cluster_slice) { - const hwy::Topology::Package& package = topology.packages[package_idx]; - package_lp = package.clusters[0].lps.First(); - cluster_slice.ForEach( - "cluster", package.clusters.size(), [&](size_t cluster_idx) { - Cluster cluster(topology.lps, enabled, package_lp, - package.clusters[cluster_idx], lps); - if (HWY_LIKELY(cluster.num_lps != 0)) { - num_lps += cluster.num_lps; // before std::move - clusters.push_back(std::move(cluster)); - } - }); - - // Note that it is possible for `clusters` to be empty if its LPs are all - // disabled. If so, the caller will ignore topology and create a single - // package and cluster. - - SortByDescendingLPs(clusters); - - // If there are multiple clusters, set their first LP to represent the - // cluster and mark them as unavailable for its pool. - if (clusters.size() > 1) { - for (Cluster& cluster : clusters) { - cluster.cluster_lp = cluster.lps.First(); - // Nonzero because if lp == 0 were enabled, it would be used as - // `package_lp` and excluded from `cluster.lps`. - HWY_ASSERT(cluster.cluster_lp != 0); - HWY_ASSERT(cluster.cluster_lp != package_lp); - cluster.lps.Clear(cluster.cluster_lp); - } - } - } - - size_t package_lp; - LPS lps; - size_t num_lps = 0; - std::vector clusters; - }; - BoundedTopology(BoundedSlice package_slice, BoundedSlice cluster_slice, BoundedSlice lp_slice) { - const bool have_threading_support = hwy::HaveThreadingSupport(); - LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl. - bool missing_cluster = false; - - if (HWY_LIKELY(have_threading_support && !topology_.packages.empty())) { - (void)GetThreadAffinity(enabled_lps); // failure = all disabled - - // No effect if topology is unknown or `enabled_lps` is empty. - package_slice.ForEach( - "package", topology_.packages.size(), [&](size_t package_idx) { - Package package(package_idx, topology_, enabled_lps, cluster_slice); - // Skip if empty - can happen due to `enabled_lps`. 
- if (HWY_LIKELY(!package.clusters.empty())) { - total_lps_ += package.num_lps; // before std::move - packages_.push_back(std::move(package)); - } - }); - - for (Package& package : packages_) { - missing_cluster = package.clusters.empty(); - if (HWY_UNLIKELY(missing_cluster)) { - fprintf( - stderr, - "Warning, found no clusters for package with %zu LPs.\nWe will " - "ignore topology and assume a single package/cluster.\n", - package.num_lps); - break; - } + // Regardless of topology, ignore LPs disabled via OS, taskset, or numactl. + LPS enabled_lps; + if (HWY_UNLIKELY(!GetThreadAffinity(enabled_lps))) { + const size_t num_lps = hwy::TotalLogicalProcessors(); + fprintf( + stderr, + "Warning, unknown OS affinity, considering all %zu LPs enabled\n.", + num_lps); + for (size_t lp = 0; lp < hwy::TotalLogicalProcessors(); ++lp) { + enabled_lps.Set(lp); } } - // Topology unknown or any package ended up empty: create a single package - // with one cluster. - if (HWY_UNLIKELY(packages_.empty() || missing_cluster)) { - // We do not bother to detect hyperthreads. Not all CPUs have two per - // core, so instead of dividing, rely on the user's `lp_slice.max`. This - // works because Linux groups LPs by HT. - const size_t num_lps = have_threading_support - ? lp_slice.Num(hwy::TotalLogicalProcessors()) - : 1; - packages_.clear(); - packages_.push_back(Package(num_lps)); - total_lps_ = num_lps; - snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", num_lps); - } else { - SortByDescendingLPs(packages_); - - const hwy::Topology::Package& tpackage0 = topology_.packages[0]; - HWY_ASSERT(!tpackage0.clusters.empty()); - const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0]; - const Package& package0 = GetPackage(0); - const Cluster& cluster0 = GetCluster(0, 0); - snprintf(topology_string_, sizeof(topology_string_), - "%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(), - tpackage0.clusters.size(), tcluster0.lps.Count(), - packages_.size(), package0.clusters.size(), cluster0.num_lps); + // Without threading support, only keep the first enabled LP; it might still + // make sense to pin the main thread. + if (HWY_UNLIKELY(!hwy::HaveThreadingSupport())) { + HWY_ASSERT(enabled_lps.Any()); + const size_t lp = enabled_lps.First(); + enabled_lps = LPS(); + enabled_lps.Set(lp); } - HWY_ASSERT(NumPackages() != 0); - for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) { - HWY_ASSERT(NumClusters(package_idx) != 0); + if (HWY_LIKELY(!topology_.packages.empty())) { + InitFromTopology(enabled_lps, package_slice, cluster_slice); } + + // Topology unknown or no packages with enabled LPs: create a single + // package with one cluster, and one node. + if (HWY_UNLIKELY(NumPackages() == 0)) { + InitFromSlice(enabled_lps, lp_slice); + } + + HWY_ASSERT(NumPackages() != 0 && NumClusters(0) != 0 && NumNodes() != 0); } - const char* TopologyString() const { return topology_string_; } - size_t NumPackages() const { return packages_.size(); } - const Package& GetPackage(size_t package_idx) const { - HWY_ASSERT(package_idx < NumPackages()); - return packages_[package_idx]; - } - Package& GetPackage(size_t package_idx) { - HWY_ASSERT(package_idx < NumPackages()); - return packages_[package_idx]; - } + const char* TopologyString() const { return topology_string_; } + size_t NumNodes() const { return nodes_.Count(); } + + class Cluster { + public: + // Topology is unknown, rely on OS affinity and user-specified slice. 
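+    // For example, passing `BoundedSlice(0, 8)` as `lp_slice` keeps at most
+    // the first eight enabled LPs (fewer if fewer are enabled).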
+ Cluster(const LPS& enabled_lps, BoundedSlice lp_slice) { + // Interpret `lp_slice` as a slice of the 1-bits of `enabled_lps`, so + // we honor both the OS affinity and the user-specified slice. Note that + // this can be used to exclude hyperthreads because Linux groups LPs by + // sibling index. For example, the first `num_cores` are not siblings. + const size_t detected = enabled_lps.Count(); + size_t enabled_idx = 0; + enabled_lps.Foreach([&](size_t lp) { + if (lp_slice.Contains(detected, enabled_idx++)) { + AddLP(lp); + } + }); + + // lp_slice can only reduce the number of `enabled_lps`, and not below 1. + HWY_ASSERT(num_workers_ != 0); + } + + Cluster(const LPS& enabled_lps, + const std::vector& all_lps, + const hwy::Topology::Cluster& tcluster) { + bool is_first_lp = true; + + tcluster.lps.Foreach([&](size_t lp) { + // Skip if not first-hyperthread or disabled. + if (all_lps[lp].smt != 0 || !enabled_lps.Get(lp)) return; + + AddLP(lp); + + // Set `node` once, and ensure subsequent nodes match - we assume there + // is only one NUMA node per cluster. + const size_t lp_node = static_cast(all_lps[lp].node); + if (is_first_lp) { + is_first_lp = false; + node_ = lp_node; + } else { + static bool warned = false; + if (lp_node != node_ && !warned) { + warned = true; + fprintf(stderr, + "WARNING: lp %zu on node %zu != cluster node %zu.\n", lp, + lp_node, node_); + } + } + }); + } + + // For SortByDescendingSize. + size_t Size() const { return num_workers_; } + + // Returns vector with all enabled LPs, used for pinning. + std::vector LPVector() const { + std::vector lps; + lps.reserve(lps_.Count()); + lps_.Foreach([&lps](size_t lp) { lps.push_back(lp); }); + return lps; + } + + size_t Node() const { return node_; } + + private: + void AddLP(size_t lp) { + HWY_ASSERT(!lps_.Get(lp)); // Foreach ensures uniqueness + lps_.Set(lp); + ++num_workers_; + } + + // Enabled LPs; if topology is known, only the ones in this cluster. + LPS lps_; + // How many workers in the per-cluster pool. If 0, this Cluster is removed. + size_t num_workers_ = 0; + // NUMA node, set from hwy::Topology::LP::node. + size_t node_ = 0; + }; // Cluster size_t NumClusters(size_t package_idx) const { - return GetPackage(package_idx).clusters.size(); + HWY_ASSERT(package_idx < NumPackages()); + return packages_[package_idx].clusters.size(); } const Cluster& GetCluster(size_t package_idx, size_t cluster_idx) const { - const Package& package = GetPackage(package_idx); + HWY_ASSERT(package_idx < NumPackages()); + const Package& package = packages_[package_idx]; HWY_ASSERT(cluster_idx < package.clusters.size()); return package.clusters[cluster_idx]; } Cluster& GetCluster(size_t package_idx, size_t cluster_idx) { - Package& package = GetPackage(package_idx); + HWY_ASSERT(package_idx < NumPackages()); + Package& package = packages_[package_idx]; HWY_ASSERT(cluster_idx < package.clusters.size()); return package.clusters[cluster_idx]; } - // Returns number of logical processors, for allocating per-thread buffers. - size_t NumLP() const { return total_lps_; } + // Returns total number of cluster workers, for deciding whether to pin. 
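+  // For example, a single package with four clusters of eight enabled cores
+  // each yields TotalWorkers() == 32.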
+ size_t TotalWorkers() const { + size_t total_workers = 0; + for (size_t package_idx = 0; package_idx < NumPackages(); ++package_idx) { + const size_t num_clusters = NumClusters(package_idx); + for (size_t cluster_idx = 0; cluster_idx < num_clusters; ++cluster_idx) { + total_workers += GetCluster(package_idx, cluster_idx).Size(); + } + } + return total_workers; + } private: + // Sort T := packages/clusters by descending 'size' so that users who only use + // one Group get the largest. + template + static void SortByDescendingSize(std::vector& groups) { + std::sort(groups.begin(), groups.end(), + [](const T& a, const T& b) { return a.Size() > b.Size(); }); + } + + struct Package { + // Topology is unknown, rely on OS affinity and user-specified slice. + Package(const LPS& enabled_lps, BoundedSlice lp_slice) { + clusters.push_back(Cluster(enabled_lps, lp_slice)); + } + + // NOTE: caller is responsible for checking whether `clusters` is empty. + Package(const LPS& enabled_lps, const hwy::Topology& topology, + size_t package_idx, BoundedSlice cluster_slice) { + const hwy::Topology::Package& tpackage = topology.packages[package_idx]; + // Populate `clusters` with the subset of clusters in `cluster_slice` that + // have any enabled LPs. If `clusters` remains empty, the caller will + // skip this `Package`. + clusters.reserve(cluster_slice.Num(tpackage.clusters.size())); + cluster_slice.Foreach( + "cluster", tpackage.clusters.size(), [&](size_t cluster_idx) { + const hwy::Topology::Cluster& tcluster = + tpackage.clusters[cluster_idx]; + Cluster cluster(enabled_lps, topology.lps, tcluster); + // Skip if empty, i.e. too few `enabled_lps`. + if (HWY_LIKELY(cluster.Size() != 0)) { + clusters.push_back(std::move(cluster)); + } + }); + SortByDescendingSize(clusters); + } + + // For SortByDescendingSize. + size_t Size() const { return clusters.size(); } + + std::vector clusters; + }; // Package + + // Main part of ctor, called when topology is known. + void InitFromTopology(const LPS& enabled_lps, BoundedSlice package_slice, + BoundedSlice cluster_slice) { + // (Possibly empty) subset of `Topology` packages that have `enabled_lps`. + package_slice.Foreach( + "package", topology_.packages.size(), [&](size_t package_idx) { + Package package(enabled_lps, topology_, package_idx, cluster_slice); + // Skip if empty, i.e. too few `enabled_lps`. + if (HWY_LIKELY(!package.clusters.empty())) { + packages_.push_back(std::move(package)); + } + }); + if (NumPackages() == 0) return; + SortByDescendingSize(packages_); + + const hwy::Topology::Package& tpackage0 = topology_.packages[0]; + HWY_ASSERT(!tpackage0.clusters.empty()); + const hwy::Topology::Cluster& tcluster0 = tpackage0.clusters[0]; + // GetCluster(0, 0) is valid because only non-empty Packages were kept. + snprintf(topology_string_, sizeof(topology_string_), + "%zux%zux%zu, using %zux%zux%zu", topology_.packages.size(), + tpackage0.clusters.size(), tcluster0.lps.Count(), packages_.size(), + NumClusters(0), GetCluster(0, 0).Size()); + + // Remember NUMA nodes of *enabled* LPs. + enabled_lps.Foreach([&](size_t lp) { + nodes_.Set(static_cast(topology_.lps[lp].node)); + }); + } + + void InitFromSlice(const LPS& enabled_lps, BoundedSlice lp_slice) { + packages_.push_back(Package(enabled_lps, lp_slice)); + + snprintf(topology_string_, sizeof(topology_string_), "LPs=%zu", + GetCluster(0, 0).Size()); + + // Assume a single NUMA node. 
+ nodes_.Set(0); + HWY_ASSERT(NumNodes() == 1); + } + hwy::Topology topology_; - size_t total_lps_ = 0; std::vector packages_; char topology_string_[96]; + LPS nodes_; }; -// Creates a hierarchy of thread pools according to BoundedTopology: one with a -// thread per enabled package; for each of those, one with a thread per enabled -// cluster (CCX/shared L3), and for each of those, the remaining enabled cores -// in that cluster. The cores representing each package and cluster are not -// included in the per-cluster pool because we support spin-waiting, hence -// there should be at most one thread per HW core. +// Creates a hierarchy of thread pools according to `BoundedTopology`: one with +// a thread per enabled package; for each of those, one with a thread per +// enabled cluster (CCX/shared L3), and for each of those, the remaining +// enabled cores in that cluster. +// +// Note that we support spin waits, thus it is important for each thread to be +// responsive, hence we do not create more than one thread per enabled core. +// For example, when there are two packages with four clusters of 8 cores, +// `AllPackages` has the main thread plus one extra thread, each `AllClusters` +// has one of the `AllPackages` threads plus three extras, each `Cluster` runs +// on one `AllClusters` thread plus seven extra workers, for a total of +// 1 + 2*3 + 2*(4*7) = 63 extras plus the main thread. // // Useful when there are tasks which should be parallelized by workers sharing a // cache, or on the same NUMA node. In both cases, individual pools have lower @@ -498,14 +343,13 @@ class NestedPools { NestedPools& operator=(NestedPools&&) = delete; // `max_threads` is the maximum number of threads to divide among all - // clusters. It does not include the package and cluster representatives. - // This is more intuitive than a per-cluster limit for users who may not be - // aware of the CPU topology. + // clusters. This is more intuitive than a per-cluster limit for users who + // may not be aware of the CPU topology. // // To ensure we do not create more threads than there are HW cores, which - // would cause huge slowdowns when spinning, `BoundedSlice` imposes upper - // bounds on the number of detected packages and clusters rather than - // defining an exact amount. + // would cause huge slowdowns when spinning, the `BoundedSlice` arguments + // only impose upper bounds on the number of detected packages and clusters + // rather than defining the actual number of threads. // // `pin` is 0 or 1 to force enable/disable, or -1 to choose automatically. NestedPools(size_t max_threads, int pin = -1, @@ -513,12 +357,14 @@ class NestedPools { BoundedSlice cluster_slice = BoundedSlice(), BoundedSlice lp_slice = BoundedSlice()) : topology_(package_slice, cluster_slice, lp_slice) { - if (pin == -1) pin = topology_.NumLP() >= 12; + if (pin == -1) pin = topology_.TotalWorkers() >= 12; packages_.resize(topology_.NumPackages()); all_packages_ = MakePool(packages_.size()); const size_t max_workers_per_package = max_threads / packages_.size(); - // Parallel to ensure we also pin the calling (main) thread. + // Each worker in all_packages_, including the main thread, will be the + // calling thread of an all_clusters->Run, and hence pinned to one of the + // `cluster.lps` if `pin`. 
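+    // For example, with two packages, the two `all_packages_` workers (the
+    // main thread plus one extra) each construct one `Package` and are pinned
+    // via the nested `all_clusters_->Run` inside its ctor.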
all_packages_->Run( 0, all_packages_->NumWorkers(), [&](uint64_t package_idx, size_t thread) { @@ -526,10 +372,24 @@ class NestedPools { packages_[package_idx] = Package( topology_, package_idx, max_workers_per_package, pin, lp_slice); }); + + // For mapping package/cluster/thread to noncontiguous TLS indices, in case + // cluster/thread counts differ. + HWY_ASSERT(!packages_.empty() && packages_.size() <= 16); + for (const Package& p : packages_) { + max_clusters_per_package_ = + HWY_MAX(max_clusters_per_package_, p.NumClusters()); + max_workers_per_cluster_ = + HWY_MAX(max_workers_per_cluster_, p.MaxWorkersPerCluster()); + } + HWY_ASSERT(max_clusters_per_package_ >= 1); + HWY_ASSERT(max_clusters_per_package_ <= 64); + HWY_ASSERT(max_workers_per_cluster_ >= 1); + HWY_ASSERT(max_workers_per_cluster_ <= 256); } // Spinning reduces the latency of barrier synchronization, but wastes lots - // of energy for long waits, so only do it during generation. This might + // of energy for long waits, so only do it during generation. Spinning might // also be unsafe in virtualized environments because we require threads to // be running on their own core and thus responsive to the barrier // synchronization. @@ -538,18 +398,45 @@ class NestedPools { hwy::ThreadPool& AllPackages() { return *all_packages_; } hwy::ThreadPool& AllClusters(size_t package_idx) { - HWY_ASSERT(package_idx < AllPackages().NumWorkers()); - return *packages_[package_idx].all_clusters; + HWY_DASSERT(package_idx < packages_.size()); + return packages_[package_idx].AllClusters(); } hwy::ThreadPool& Cluster(size_t package_idx, size_t cluster_idx) { - HWY_ASSERT(cluster_idx < AllClusters(package_idx).NumWorkers()); - return *packages_[package_idx].clusters[cluster_idx]; + HWY_DASSERT(package_idx < packages_.size()); + return packages_[package_idx].Cluster(cluster_idx); } + // For binding to NUMA nodes. + size_t Node(size_t package_idx, size_t cluster_idx) const { + return topology_.GetCluster(package_idx, cluster_idx).Node(); + } + + // Reasonably tight upper bound for allocating thread-local storage (TLS). + size_t MaxWorkers() const { + return packages_.size() * max_clusters_per_package_ * + max_workers_per_cluster_; + } + // Returns the first of `cluster.NumWorkers()` TLS indices, to which callers + // add the worker index given by `cluster.Run`. + size_t WorkerOffset(size_t package_idx, size_t cluster_idx) const { + return (package_idx * max_clusters_per_package_ + cluster_idx) * + max_workers_per_cluster_; + } + + // For Allocator + const BoundedTopology& Topology() const { return topology_; } const char* TopologyString() const { return topology_.TopologyString(); } - // Returns number of logical processors, for allocating per-thread buffers. - size_t NumLP() const { return topology_.NumLP(); } + // Returns a single pool on the first package: either one thread per cluster + // if there is more than one, which maximizes available memory bandwidth, or + // the first cluster, which is typically the whole package. For use by callers + // that only parallelize over a 1D range, as opposed to the nested + // parallelism of `StaticPartitionRowsAndCols`. + hwy::ThreadPool& Pool() { + // Only one cluster: use its pool, typically a whole socket. + if (AllClusters(0).NumWorkers() == 1) return Cluster(0, 0); + return AllClusters(0); + } private: // `max_or_zero` == 0 means no limit. 
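+  // For example, CapIfNonZero(8, /*max_or_zero=*/0) yields 8, while
+  // CapIfNonZero(8, 4) caps the result at 4.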
@@ -569,69 +456,72 @@ class NestedPools { } class Package { - static PoolPtr CreateClusterPool(const BoundedTopology::Cluster& cluster, - size_t max_cluster_workers, int pin, - BoundedSlice lp_slice) { - PoolPtr pool = - MakePool(CapIfNonZero(cluster.num_lps, max_cluster_workers)); - - if (!pin) return pool; - // Else: pin all new threads AND the calling thread from `all_clusters`. - - // We know the topology: pin to this cluster's cores, including the - // calling thread from `all_clusters`. - if (cluster.lps.Any()) { - std::vector lps; - lps.reserve(cluster.num_lps); - cluster.lps.Foreach([&lps](size_t lp) { lps.push_back(lp); }); - - pool->Run(0, pool->NumWorkers(), [&lps](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task - hwy::PinThreadToLogicalProcessor(lps[task]); - }); - } else { - // Pin to consecutive LPs. - pool->Run(0, pool->NumWorkers(), - [lp_slice](uint64_t task, size_t thread) { - HWY_ASSERT(task == thread); // each worker has one task - hwy::PinThreadToLogicalProcessor(lp_slice.skip + thread); - }); - } - return pool; - } - public: Package() = default; // for vector Package(const BoundedTopology& topology, size_t package_idx, size_t max_workers_per_package, int pin, BoundedSlice lp_slice) { - clusters.resize(topology.NumClusters(package_idx)); + // Pre-allocate because elements are set concurrently. + clusters_.resize(topology.NumClusters(package_idx)); const size_t max_workers_per_cluster = - max_workers_per_package / clusters.size(); + max_workers_per_package / clusters_.size(); - all_clusters = MakePool(clusters.size()); - // Parallel so we also pin the calling thread from `all_packages_`. - all_clusters->Run( - 0, all_clusters->NumWorkers(), + all_clusters_ = MakePool(clusters_.size()); + // Parallel so we also pin the calling worker in `all_clusters` to + // `cluster.lps`. + all_clusters_->Run( + 0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) { HWY_ASSERT(cluster_idx == thread); // each thread has one task const BoundedTopology::Cluster& cluster = topology.GetCluster(package_idx, cluster_idx); - clusters[cluster_idx] = CreateClusterPool( - cluster, max_workers_per_cluster, pin, lp_slice); + clusters_[cluster_idx] = + MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster)); + if (HWY_LIKELY(pin)) { + // Pin threads AND the calling thread from `all_clusters` to lps. 
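+              // Each worker runs exactly one task (task == thread), so worker
+              // i is pinned to lps[i]. By this point `pin` has already been
+              // resolved from -1 to 0 or 1 in the NestedPools ctor.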
+ const std::vector lps = cluster.LPVector(); + HWY_ASSERT(clusters_[cluster_idx]->NumWorkers() <= lps.size()); + clusters_[cluster_idx]->Run( + 0, clusters_[cluster_idx]->NumWorkers(), + [&lps](uint64_t task, size_t thread) { + HWY_ASSERT(task == thread); // each worker has one task + hwy::PinThreadToLogicalProcessor(lps[task]); + }); + } }); } - std::vector clusters; - PoolPtr all_clusters; - }; + size_t NumClusters() const { return clusters_.size(); } + size_t MaxWorkersPerCluster() const { + size_t max_workers_per_cluster = 0; + for (const PoolPtr& cluster : clusters_) { + max_workers_per_cluster = + HWY_MAX(max_workers_per_cluster, cluster->NumWorkers()); + } + return max_workers_per_cluster; + } + + hwy::ThreadPool& AllClusters() { return *all_clusters_; } + hwy::ThreadPool& Cluster(size_t cluster_idx) { + HWY_DASSERT(cluster_idx < clusters_.size()); + return *clusters_[cluster_idx]; + } + + void SetWaitMode(hwy::PoolWaitMode wait_mode) { + all_clusters_->SetWaitMode(wait_mode); + for (PoolPtr& cluster : clusters_) { + cluster->SetWaitMode(wait_mode); + } + } + + private: + std::vector clusters_; + PoolPtr all_clusters_; + }; // Package void SetWaitMode(hwy::PoolWaitMode wait_mode) { all_packages_->SetWaitMode(wait_mode); for (Package& package : packages_) { - package.all_clusters->SetWaitMode(wait_mode); - for (PoolPtr& cluster : package.clusters) { - cluster->SetWaitMode(wait_mode); - } + package.SetWaitMode(wait_mode); } } @@ -639,12 +529,11 @@ class NestedPools { std::vector packages_; PoolPtr all_packages_; -}; -static inline NestedPools CreateSinglePool(size_t max_threads, int pin = -1) { - const BoundedSlice one(0, 1); - return NestedPools(max_threads, pin, one, one); -} + // For TLS indices. + size_t max_clusters_per_package_ = 0; + size_t max_workers_per_cluster_ = 0; +}; } // namespace gcpp diff --git a/util/threading_test.cc b/util/threading_test.cc index b5e8ff2..7f01f2c 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -35,27 +35,41 @@ TEST(ThreadingTest, TestBoundedSlice) { { BoundedSlice slice; std::vector expected; - slice.ForEach(name, 10, [&](size_t i) { expected.push_back(i); }); - EXPECT_EQ(10, slice.Num(10)); + const size_t detected = 10; + slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); }); + EXPECT_EQ(10, slice.Num(detected)); EXPECT_THAT(expected, ElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)); + EXPECT_TRUE(slice.Contains(detected, 0)); + EXPECT_TRUE(slice.Contains(detected, 9)); + EXPECT_FALSE(slice.Contains(detected, 10)); } // One arg: skip first N { BoundedSlice slice(3); std::vector expected; - slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); }); - EXPECT_EQ(6, slice.Num(9)); + const size_t detected = 9; + slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); }); + EXPECT_EQ(6, slice.Num(detected)); EXPECT_THAT(expected, ElementsAre(3, 4, 5, 6, 7, 8)); + EXPECT_FALSE(slice.Contains(detected, 2)); + EXPECT_TRUE(slice.Contains(detected, 3)); + EXPECT_TRUE(slice.Contains(detected, 8)); + EXPECT_FALSE(slice.Contains(detected, 9)); } // Both args: skip first N, then use at most M { BoundedSlice slice(3, 2); std::vector expected; - slice.ForEach(name, 9, [&](size_t i) { expected.push_back(i); }); - EXPECT_EQ(2, slice.Num(9)); + const size_t detected = 9; + slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); }); + EXPECT_EQ(2, slice.Num(detected)); EXPECT_THAT(expected, ElementsAre(3, 4)); + EXPECT_FALSE(slice.Contains(detected, 2)); + EXPECT_TRUE(slice.Contains(detected, 3)); + 
EXPECT_TRUE(slice.Contains(detected, 4)); + EXPECT_FALSE(slice.Contains(detected, 5)); } // Both args, but `max > detected - skip`: fewer than limit. Note that @@ -63,9 +77,13 @@ TEST(ThreadingTest, TestBoundedSlice) { { BoundedSlice slice(3, 2); std::vector expected; - slice.ForEach(name, 4, [&](size_t i) { expected.push_back(i); }); - EXPECT_EQ(1, slice.Num(4)); + const size_t detected = 4; + slice.Foreach(name, detected, [&](size_t i) { expected.push_back(i); }); + EXPECT_EQ(1, slice.Num(detected)); EXPECT_THAT(expected, ElementsAre(3)); + EXPECT_FALSE(slice.Contains(detected, 2)); + EXPECT_TRUE(slice.Contains(detected, 3)); + EXPECT_FALSE(slice.Contains(detected, 4)); } } @@ -76,8 +94,6 @@ TEST(ThreadingTest, TestBoundedTopology) { { BoundedTopology topology(all, all, all); fprintf(stderr, "%s\n", topology.TopologyString()); - ASSERT_NE(0, topology.NumPackages()); - ASSERT_NE(0, topology.NumClusters(0)); } // Max one package @@ -85,14 +101,12 @@ TEST(ThreadingTest, TestBoundedTopology) { BoundedTopology topology(one, all, all); fprintf(stderr, "%s\n", topology.TopologyString()); ASSERT_EQ(1, topology.NumPackages()); - ASSERT_NE(0, topology.NumClusters(0)); } // Max one cluster { BoundedTopology topology(all, one, all); fprintf(stderr, "%s\n", topology.TopologyString()); - ASSERT_NE(0, topology.NumPackages()); ASSERT_EQ(1, topology.NumClusters(0)); } }
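
For reference, a minimal usage sketch of the new `NestedPools` API (illustration only, not part of the patch). It sizes per-worker storage with `MaxWorkers()` and combines `WorkerOffset()` with the per-cluster worker index to obtain disjoint TLS slots. `ExampleNestedRun`, the `max_threads` budget of 64, and the accumulation into `tls` are arbitrary choices for the example.

#include <stddef.h>
#include <stdint.h>

#include <vector>

#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading.h"

namespace gcpp {

// Nested parallelism over packages, clusters and cluster workers. Each worker
// writes to its own slot of `tls`, so no synchronization is required.
void ExampleNestedRun() {
  NestedPools pools(/*max_threads=*/64);  // arbitrary budget for illustration
  std::vector<double> tls(pools.MaxWorkers(), 0.0);

  pools.AllPackages().Run(
      0, pools.AllPackages().NumWorkers(), [&](uint64_t pkg, size_t) {
        const size_t package_idx = static_cast<size_t>(pkg);
        hwy::ThreadPool& clusters = pools.AllClusters(package_idx);
        clusters.Run(0, clusters.NumWorkers(), [&](uint64_t cl, size_t) {
          const size_t cluster_idx = static_cast<size_t>(cl);
          hwy::ThreadPool& cluster = pools.Cluster(package_idx, cluster_idx);
          const size_t base = pools.WorkerOffset(package_idx, cluster_idx);
          cluster.Run(0, cluster.NumWorkers(),
                      [&](uint64_t /*task*/, size_t worker) {
                        tls[base + worker] += 1.0;  // disjoint TLS slot
                      });
        });
      });

  // Callers that only parallelize over a 1D range can instead use
  // pools.Pool(): one thread per cluster of the first package, or the first
  // cluster's pool if it is the only one.
}

}  // namespace gcpp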