Utilize multiple cores to read weight batches.

PiperOrigin-RevId: 811893059
This commit is contained in:
Nitin Gangahar 2025-09-26 11:27:56 -07:00 committed by Copybara-Service
parent d15731d201
commit 667a3f117a
1 changed files with 14 additions and 12 deletions

View File

@ -465,14 +465,16 @@ static void ReadBatches(const BlobReader& reader,
ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
// >5x speedup from parallel reads when cached.
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
ParallelFor(ParallelismStrategy::kHierarchical,
batches.size(), ctx, /*cluster_idx=*/0,
[&](uint64_t task, size_t thread) {
PROFILER_ZONE3(ctx.profiler, thread, zone);
const IOBatch& batch = batches[i];
const IOBatch& batch = batches[task];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());
if (bytes_read != batch.TotalBytes()) {
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
static_cast<size_t>(batch.Offset()),
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.",
key.c_str(), static_cast<size_t>(batch.Offset()),
static_cast<size_t>(batch.TotalBytes()),
static_cast<size_t>(bytes_read));
}