Utilize multiple cores to read weight batches.

PiperOrigin-RevId: 811893059
This commit is contained in:
Nitin Gangahar 2025-09-26 11:27:56 -07:00 committed by Copybara-Service
parent d15731d201
commit 667a3f117a
1 changed file with 14 additions and 12 deletions

View File

@@ -465,18 +465,20 @@ static void ReadBatches(const BlobReader& reader,
ThreadingContext& ctx) { ThreadingContext& ctx) {
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
// >5x speedup from parallel reads when cached. // >5x speedup from parallel reads when cached.
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) { ParallelFor(ParallelismStrategy::kHierarchical,
PROFILER_ZONE3(ctx.profiler, thread, zone); batches.size(), ctx, /*cluster_idx=*/0,
const IOBatch& batch = batches[i]; [&](uint64_t task, size_t thread) {
const std::string& key = reader.Keys()[batch.KeyIdx()]; PROFILER_ZONE3(ctx.profiler, thread, zone);
const uint64_t bytes_read = batch.Read(reader.file()); const IOBatch& batch = batches[task];
if (bytes_read != batch.TotalBytes()) { const std::string& key = reader.Keys()[batch.KeyIdx()];
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(), const uint64_t bytes_read = batch.Read(reader.file());
static_cast<size_t>(batch.Offset()), if (bytes_read != batch.TotalBytes()) {
static_cast<size_t>(batch.TotalBytes()), HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.",
static_cast<size_t>(bytes_read)); key.c_str(), static_cast<size_t>(batch.Offset()),
} static_cast<size_t>(batch.TotalBytes()),
}); static_cast<size_t>(bytes_read));
}
});
} }
// Aborts on error. Updates `mode` to the actual mode used. Returns mapped // Aborts on error. Updates `mode` to the actual mode used. Returns mapped