From 799c264df38aff7a3f8dce786379820b82dfb40b Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 31 Jul 2025 08:44:47 -0700 Subject: [PATCH] Pre-tune thread pool before matmul Also improve profiler annotations - remove near-zero ones and add more for startup PiperOrigin-RevId: 789352414 --- BUILD.bazel | 2 ++ gemma/gemma.cc | 1 + gemma/model_store.cc | 5 +++ gemma/run.cc | 10 +++--- gemma/tokenizer.cc | 1 + gemma/weights.cc | 5 +-- io/BUILD.bazel | 1 + io/blob_store.cc | 10 +++--- io/blob_store.h | 2 +- paligemma/BUILD.bazel | 6 ---- util/threading_context.cc | 70 ++++++++++++++++++++++++++++++++++++++- 11 files changed, 93 insertions(+), 20 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 2628bc3..477b3f0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -93,6 +93,7 @@ cc_library( ":threading", ":topology", "@highway//:hwy", + "@highway//:hwy_test_util", "@highway//:profiler", ], ) @@ -205,6 +206,7 @@ cc_library( "//io:blob_store", "//io:fields", "@highway//:hwy", + "@highway//:profiler", "@highway//:thread_pool", ], ) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 496c21d..8c05306 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -609,6 +609,7 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), inference_(inference) { + // Negligible CPU time in the ctor body (except ReadFromBlobs). weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx); // Read everything into memory, or `weights_.mapped_` keeps the mapping alive. 
diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 8f6c138..ca454bd 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -36,6 +36,7 @@ #include "util/basics.h" #include "util/threading_context.h" #include "hwy/base.h" +#include "hwy/profiler.h" namespace gcpp { @@ -60,6 +61,8 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) { // Reads it from a blob or from a separate file if pre-2025. static std::string ReadTokenizer(BlobReader& reader, const Path& tokenizer_path) { + PROFILER_ZONE("Startup.ReadTokenizer"); + std::string tokenizer; // Check prevents `CallWithSpan` from printing a warning. if (reader.Find(kTokenizerName)) { @@ -306,6 +309,8 @@ bool ModelStore::ReadMatPtrs(BlobReader& reader) { // Check first to prevent `CallWithSpan` from printing a warning. if (!reader.Find(kMatPtrsName)) return false; + PROFILER_ZONE("Startup.ReadMatPtrs"); + // For verifying `config_.weight`. size_t min_bits = ~size_t{0}; Type weight_type = Type::kUnknown; diff --git a/gemma/run.cc b/gemma/run.cc index 7cbc4de..dd5165c 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -78,15 +78,14 @@ std::string GetPromptFromStream(std::istream& input, int verbosity, // Get prompt either from interactive input or command line std::string GetPrompt(const InferenceArgs& inference) { - PROFILER_ZONE("Gen.input"); // If prompt is provided via command line, use that - if (!inference.prompt.empty()) { - return inference.prompt; - } + if (!inference.prompt.empty()) return inference.prompt; if (!inference.prompt_file.Empty()) { + PROFILER_ZONE("Gen.ReadPrompt"); return ReadFileToString(inference.prompt_file); } + PROFILER_ZONE("Gen.input"); return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line); } @@ -299,8 +298,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, int main(int argc, char** argv) { gcpp::InternalInit(); { - PROFILER_ZONE("Startup.misc"); - + // Negligible CPU time. 
gcpp::LoaderArgs loader(argc, argv); gcpp::ThreadingArgs threading(argc, argv); gcpp::InferenceArgs inference(argc, argv); diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc index 6e39f27..e0e071c 100644 --- a/gemma/tokenizer.cc +++ b/gemma/tokenizer.cc @@ -104,6 +104,7 @@ bool GemmaTokenizer::Decode(const std::vector& ids, return impl_->Decode(ids, detokenized); } +// Negligible CPU time in the ctor body. GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) { sot_user_.reserve(3); diff --git a/gemma/weights.cc b/gemma/weights.cc index 3418acf..2f363e6 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -344,10 +344,9 @@ static void AllocateAndBindAll(std::vector& tensors, }); } -// Mode == kMap +// Mode == kMap. CPU time is negligible. static void MapAll(const std::vector& tensors, const MapPtr& mapped, uint64_t file_bytes) { - PROFILER_ZONE("Startup.Weights.Map"); for (size_t i = 0; i < tensors.size(); ++i) { // SetPtr does not change the stride, but it is expected to be packed // because that is what Compress() writes to the file. @@ -521,6 +520,8 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model, const InferenceArgs& inference, std::vector& mat_owners, ThreadingContext& ctx) { + PROFILER_ZONE("Startup.ReadFromBlobs"); + // List of tensors to read/map, and where from. 
std::vector tensors; diff --git a/io/BUILD.bazel b/io/BUILD.bazel index 1ca42a4..eb3a785 100644 --- a/io/BUILD.bazel +++ b/io/BUILD.bazel @@ -74,6 +74,7 @@ cc_library( "//:basics", "//:threading_context", "@highway//:hwy", + "@highway//:profiler", "@highway//:thread_pool", ], ) diff --git a/io/blob_store.cc b/io/blob_store.cc index c3fbea1..00d4fff 100644 --- a/io/blob_store.cc +++ b/io/blob_store.cc @@ -30,6 +30,7 @@ #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_compiler_arch.h" +#include "hwy/profiler.h" namespace gcpp { @@ -413,10 +414,11 @@ class BlobStore { std::vector directory_; // two per blob, see `SetRange`. }; // BlobStore -BlobReader::BlobReader(const Path& blob_path) - : blob_path_(blob_path), - file_(OpenFileOrAbort(blob_path, "r")), - file_bytes_(file_->FileSize()) { +BlobReader::BlobReader(const Path& blob_path) : blob_path_(blob_path) { + PROFILER_ZONE("Startup.BlobReader"); + + file_ = OpenFileOrAbort(blob_path, "r"); + file_bytes_ = file_->FileSize(); if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str()); BlobStore bs(*file_); diff --git a/io/blob_store.h b/io/blob_store.h index f7103d7..b7dba0f 100644 --- a/io/blob_store.h +++ b/io/blob_store.h @@ -104,7 +104,7 @@ class BlobReader { private: Path blob_path_; std::unique_ptr file_; - const uint64_t file_bytes_; + uint64_t file_bytes_; // const after ctor std::vector keys_; std::vector ranges_; diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel index 5b5fdd4..ebed1ca 100644 --- a/paligemma/BUILD.bazel +++ b/paligemma/BUILD.bazel @@ -34,16 +34,13 @@ cc_library( srcs = ["paligemma_helper.cc"], hdrs = ["paligemma_helper.h"], deps = [ - ":image", "//:allocator", "//:benchmark_helper", "//:configs", "//:gemma_args", "//:gemma_lib", "//compression:types", - "//io", "@highway//:hwy", - "@highway//:profiler", ], ) @@ -59,15 +56,12 @@ cc_test( ], deps = [ ":paligemma_helper", - "//devtools/build/runtime:get_runfiles_dir", 
"@googletest//:gtest_main", # buildcleaner: keep "//:allocator", "//:benchmark_helper", "//:configs", "//:gemma_lib", - "//compression:types", "//io", - "@highway//:hwy", "@highway//:hwy_test_util", ], ) diff --git a/util/threading_context.cc b/util/threading_context.cc index 14eca95..3bb1080 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -15,13 +15,81 @@ #include "util/threading_context.h" +#include +#include + +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/profiler.h" +#include "hwy/tests/test_util.h" // RandomState + namespace gcpp { +// Invokes `pool.Run` with varying task counts until auto-tuning completes, or +// an upper bound just in case. +static void TunePool(hwy::ThreadPool& pool) { + const size_t num_workers = pool.NumWorkers(); + // pool.Run would just be a serial loop without auto-tuning, so skip. + if (num_workers == 1) return; + + // Randomly pick among several task counts to defeat branch prediction. + const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1, + num_workers * 5, num_workers * 20}; + + // Count tasks executed to ensure workers aren't optimized out. One per + // cache line to avoid false sharing. + const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t); + + std::vector counters(num_workers * kSizePerLine); + size_t prev_total = 0; // avoids having to reset counters. + + hwy::RandomState rng; + for (size_t rep = 0; rep < 500; ++rep) { + if (HWY_UNLIKELY(pool.AutoTuneComplete())) { + break; + } + + const uint64_t r = hwy::Random64(&rng); + const size_t begin = r >> 2; + const size_t end = begin + num_tasks[r & 3]; + + pool.Run(begin, end, [&](uint64_t task, size_t thread) { + HWY_ASSERT(begin <= task && task < end); + HWY_ASSERT(thread < num_workers); + counters[thread * kSizePerLine]++; + }); + + // Reduce count and ensure it matches the expected number of tasks. 
+ size_t total = 0; + for (size_t i = 0; i < num_workers; ++i) { + total += counters[i * kSizePerLine]; + } + const size_t expected = end - begin; + HWY_ASSERT(total == prev_total + expected); + prev_total += expected; + } +} + ThreadingContext::ThreadingContext(const ThreadingArgs& args) : topology(BoundedSlice(args.skip_packages, args.max_packages), BoundedSlice(args.skip_clusters, args.max_clusters), BoundedSlice(args.skip_lps, args.max_lps)), allocator(topology, args.bind != Tristate::kFalse), - pools(topology, allocator, args.max_threads, args.pin) {} + pools(topology, allocator, args.max_threads, args.pin) { + PROFILER_ZONE("Startup.ThreadingContext autotune"); + TunePool(pools.AllPackages()); + for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) { + hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx); + TunePool(clusters); + + // Run in parallel because Turin CPUs have 16 clusters, and in real usage, + // we often run all at the same time. + clusters.Run(0, clusters.NumWorkers(), + [&](uint64_t cluster_idx, size_t /*thread*/) { + TunePool(pools.Cluster(pkg_idx, cluster_idx)); + }); + } +} } // namespace gcpp