Pre-tune thread pool before matmul

Also improve profiler annotations: remove near-zero zones and add more for startup

PiperOrigin-RevId: 789352414
Jan Wassenberg 2025-07-31 08:44:47 -07:00 committed by Copybara-Service
parent 50ee1a3e92
commit 799c264df3
11 changed files with 93 additions and 20 deletions
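For context on the annotations below: `PROFILER_ZONE` from Highway's `hwy/profiler.h` opens a scoped zone object that is credited to the named entry when it goes out of scope. A minimal sketch of the pattern this commit applies to startup code (the function and zone names here are illustrative, not from the repo):

#include "hwy/profiler.h"

// Hypothetical startup step. The zone measures exactly the enclosing scope,
// so near-zero work can be left unannotated while expensive phases each get
// a named "Startup.*" entry in the profile.
static void LoadModelFiles() {
  PROFILER_ZONE("Startup.LoadModelFiles");
  // ... file I/O and parsing ...
}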

View File

@@ -93,6 +93,7 @@ cc_library(
":threading",
":topology",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:profiler",
],
)
@@ -205,6 +206,7 @@ cc_library(
"//io:blob_store",
"//io:fields",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

View File

@@ -609,6 +609,7 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
// Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
mat_owners_, ctx);
// Read everything into memory, or `weights_.mapped_` keeps the mapping alive.

View File

@@ -36,6 +36,7 @@
#include "util/basics.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
namespace gcpp {
@@ -60,6 +61,8 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
// Reads it from a blob or from a separate file if pre-2025.
static std::string ReadTokenizer(BlobReader& reader,
const Path& tokenizer_path) {
PROFILER_ZONE("Startup.ReadTokenizer");
std::string tokenizer;
// Check prevents `CallWithSpan` from printing a warning.
if (reader.Find(kTokenizerName)) {
@@ -306,6 +309,8 @@ bool ModelStore::ReadMatPtrs(BlobReader& reader) {
// Check first to prevent `CallWithSpan` from printing a warning.
if (!reader.Find(kMatPtrsName)) return false;
PROFILER_ZONE("Startup.ReadMatPtrs");
// For verifying `config_.weight`.
size_t min_bits = ~size_t{0};
Type weight_type = Type::kUnknown;

View File

@@ -78,15 +78,14 @@ std::string GetPromptFromStream(std::istream& input, int verbosity,
// Get prompt either from interactive input or command line
std::string GetPrompt(const InferenceArgs& inference) {
PROFILER_ZONE("Gen.input");
// If prompt is provided via command line, use that
if (!inference.prompt.empty()) {
return inference.prompt;
}
if (!inference.prompt.empty()) return inference.prompt;
if (!inference.prompt_file.Empty()) {
PROFILER_ZONE("Gen.ReadPrompt");
return ReadFileToString(inference.prompt_file);
}
PROFILER_ZONE("Gen.input");
return GetPromptFromStream(std::cin, inference.verbosity, inference.eot_line);
}
@@ -299,8 +298,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
int main(int argc, char** argv) {
gcpp::InternalInit();
{
PROFILER_ZONE("Startup.misc");
// Negligible CPU time.
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);

View File

@@ -104,6 +104,7 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
return impl_->Decode(ids, detokenized);
}
// Negligible CPU time in the ctor body.
GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer,
Model model) {
sot_user_.reserve(3);

View File

@@ -344,10 +344,9 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
});
}
// Mode == kMap
// Mode == kMap. CPU time is negligible.
static void MapAll(const std::vector<TensorToRead>& tensors,
const MapPtr& mapped, uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.Map");
for (size_t i = 0; i < tensors.size(); ++i) {
// SetPtr does not change the stride, but it is expected to be packed
// because that is what Compress() writes to the file.
@@ -521,6 +520,8 @@ WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners,
ThreadingContext& ctx) {
PROFILER_ZONE("Startup.ReadFromBlobs");
// List of tensors to read/map, and where from.
std::vector<TensorToRead> tensors;

View File

@@ -74,6 +74,7 @@ cc_library(
"//:basics",
"//:threading_context",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

View File

@@ -30,6 +30,7 @@
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h"
namespace gcpp {
@@ -413,10 +414,11 @@ class BlobStore {
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
}; // BlobStore
BlobReader::BlobReader(const Path& blob_path)
: blob_path_(blob_path),
file_(OpenFileOrAbort(blob_path, "r")),
file_bytes_(file_->FileSize()) {
BlobReader::BlobReader(const Path& blob_path) : blob_path_(blob_path) {
PROFILER_ZONE("Startup.BlobReader");
file_ = OpenFileOrAbort(blob_path, "r");
file_bytes_ = file_->FileSize();
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
BlobStore bs(*file_);
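A note on the `BlobReader` change above: `PROFILER_ZONE` creates a scope-bound object, so it cannot cover work done in a member-initializer list. Moving `OpenFileOrAbort` and `FileSize` into the constructor body is what lets the new "Startup.BlobReader" zone time them. A minimal sketch of the same pattern with stand-in file calls (not the repo's `File` API):

#include <cstdint>
#include <cstdio>

#include "hwy/profiler.h"

class TimedReader {
 public:
  explicit TimedReader(const char* path) {
    // The zone opens here and closes at the end of the ctor body, so the
    // open and size calls below are measured; initializer-list work would
    // not be.
    PROFILER_ZONE("Startup.TimedReader");
    file_ = std::fopen(path, "rb");
    if (file_ != nullptr) {
      std::fseek(file_, 0, SEEK_END);
      bytes_ = static_cast<uint64_t>(std::ftell(file_));
    }
  }

 private:
  std::FILE* file_ = nullptr;
  uint64_t bytes_ = 0;
};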

View File

@@ -104,7 +104,7 @@ class BlobReader {
private:
Path blob_path_;
std::unique_ptr<File> file_;
const uint64_t file_bytes_;
uint64_t file_bytes_; // const after ctor
std::vector<std::string> keys_;
std::vector<BlobRange> ranges_;

View File

@@ -34,16 +34,13 @@ cc_library(
srcs = ["paligemma_helper.cc"],
hdrs = ["paligemma_helper.h"],
deps = [
":image",
"//:allocator",
"//:benchmark_helper",
"//:configs",
"//:gemma_args",
"//:gemma_lib",
"//compression:types",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)
@@ -59,15 +56,12 @@ cc_test(
],
deps = [
":paligemma_helper",
"//devtools/build/runtime:get_runfiles_dir",
"@googletest//:gtest_main", # buildcleaner: keep
"//:allocator",
"//:benchmark_helper",
"//:configs",
"//:gemma_lib",
"//compression:types",
"//io",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)

View File

@@ -15,13 +15,81 @@
#include "util/threading_context.h"
#include <stddef.h>
#include <stdint.h>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/profiler.h"
#include "hwy/tests/test_util.h" // RandomState
namespace gcpp {
// Invokes `pool.Run` with varying task counts until auto-tuning completes, or
// until an upper bound on repetitions is reached, just in case it does not.
static void TunePool(hwy::ThreadPool& pool) {
const size_t num_workers = pool.NumWorkers();
// pool.Run would just be a serial loop without auto-tuning, so skip.
if (num_workers == 1) return;
// Task counts, chosen at random each iteration, to defeat branch prediction.
const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
num_workers * 5, num_workers * 20};
// Count tasks executed to ensure workers aren't optimized out. One per
// cache line to avoid false sharing.
const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
std::vector<size_t> counters(num_workers * kSizePerLine);
size_t prev_total = 0; // avoids having to reset counters.
hwy::RandomState rng;
for (size_t rep = 0; rep < 500; ++rep) {
if (HWY_UNLIKELY(pool.AutoTuneComplete())) {
break;
}
const uint64_t r = hwy::Random64(&rng);
const size_t begin = r >> 2;
const size_t end = begin + num_tasks[r & 3];
pool.Run(begin, end, [&](uint64_t task, size_t thread) {
HWY_ASSERT(begin <= task && task < end);
HWY_ASSERT(thread < num_workers);
counters[thread * kSizePerLine]++;
});
// Reduce count and ensure it matches the expected number of tasks.
size_t total = 0;
for (size_t i = 0; i < num_workers; ++i) {
total += counters[i * kSizePerLine];
}
const size_t expected = end - begin;
HWY_ASSERT(total == prev_total + expected);
prev_total += expected;
}
}
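(Design note on `TunePool`: the counters are spaced `HWY_ALIGNMENT` bytes apart so each worker increments a private cache line; without that padding, false sharing between workers would perturb the very timings the auto-tuner is trying to measure. The final `HWY_ASSERT` then checks that the increments add up to exactly the number of tasks issued.)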
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
: topology(BoundedSlice(args.skip_packages, args.max_packages),
BoundedSlice(args.skip_clusters, args.max_clusters),
BoundedSlice(args.skip_lps, args.max_lps)),
allocator(topology, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {}
pools(topology, allocator, args.max_threads, args.pin) {
PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePool(pools.AllPackages());
for (size_t pkg_idx = 0; pkg_idx < pools.NumPackages(); ++pkg_idx) {
hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
TunePool(clusters);
// Run in parallel because Turin CPUs have 16 clusters, and in real usage, we
// often run them all at the same time.
clusters.Run(0, clusters.NumWorkers(),
[&](uint64_t cluster_idx, size_t /*thread*/) {
TunePool(pools.Cluster(pkg_idx, cluster_idx));
});
}
}
} // namespace gcpp
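To illustrate what pre-tuning buys: constructing a `ThreadingContext` now drives `TunePool` over the package, cluster, and within-cluster pools, so the first real `pool.Run` (e.g. in matmul) already executes with a tuned configuration. A hedged usage sketch, assuming `ThreadingArgs` is default-constructible as in the rest of gemma.cpp:

#include "util/threading_context.h"

int main() {
  gcpp::ThreadingArgs args;  // default settings
  // The ctor now spends a bounded amount of time auto-tuning, reported under
  // the "Startup.ThreadingContext autotune" profiler zone.
  gcpp::ThreadingContext ctx(args);
  // Later pool.Run calls, the first matmul included, are no longer skewed by
  // auto-tuning overhead.
  return 0;
}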