From 667a3f117af5e25a4bd16d4c17eef72f3731591b Mon Sep 17 00:00:00 2001 From: Nitin Gangahar Date: Fri, 26 Sep 2025 11:27:56 -0700 Subject: [PATCH] Utilize multiple cores to read weight batches. PiperOrigin-RevId: 811893059 --- gemma/weights.cc | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/gemma/weights.cc b/gemma/weights.cc index 425a752..fb59297 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -465,18 +465,20 @@ static void ReadBatches(const BlobReader& reader, ThreadingContext& ctx) { static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches"); // >5x speedup from parallel reads when cached. - ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) { - PROFILER_ZONE3(ctx.profiler, thread, zone); - const IOBatch& batch = batches[i]; - const std::string& key = reader.Keys()[batch.KeyIdx()]; - const uint64_t bytes_read = batch.Read(reader.file()); - if (bytes_read != batch.TotalBytes()) { - HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(), - static_cast(batch.Offset()), - static_cast(batch.TotalBytes()), - static_cast(bytes_read)); - } - }); + ParallelFor(ParallelismStrategy::kHierarchical, + batches.size(), ctx, /*cluster_idx=*/0, + [&](uint64_t task, size_t thread) { + PROFILER_ZONE3(ctx.profiler, thread, zone); + const IOBatch& batch = batches[task]; + const std::string& key = reader.Keys()[batch.KeyIdx()]; + const uint64_t bytes_read = batch.Read(reader.file()); + if (bytes_read != batch.TotalBytes()) { + HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", + key.c_str(), static_cast(batch.Offset()), + static_cast(batch.TotalBytes()), + static_cast(bytes_read)); + } + }); } // Aborts on error. Updates `mode` to the actual mode used. Returns mapped