mirror of https://github.com/google/gemma.cpp.git
Utilize multiple cores to read weight batches.
PiperOrigin-RevId: 811893059
This commit is contained in:
parent
d15731d201
commit
667a3f117a
|
|
@ -465,18 +465,20 @@ static void ReadBatches(const BlobReader& reader,
|
|||
ThreadingContext& ctx) {
|
||||
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
||||
// >5x speedup from parallel reads when cached.
|
||||
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
||||
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||
const IOBatch& batch = batches[i];
|
||||
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
||||
const uint64_t bytes_read = batch.Read(reader.file());
|
||||
if (bytes_read != batch.TotalBytes()) {
|
||||
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
|
||||
static_cast<size_t>(batch.Offset()),
|
||||
static_cast<size_t>(batch.TotalBytes()),
|
||||
static_cast<size_t>(bytes_read));
|
||||
}
|
||||
});
|
||||
ParallelFor(ParallelismStrategy::kHierarchical,
|
||||
batches.size(), ctx, /*cluster_idx=*/0,
|
||||
[&](uint64_t task, size_t thread) {
|
||||
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||
const IOBatch& batch = batches[task];
|
||||
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
||||
const uint64_t bytes_read = batch.Read(reader.file());
|
||||
if (bytes_read != batch.TotalBytes()) {
|
||||
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.",
|
||||
key.c_str(), static_cast<size_t>(batch.Offset()),
|
||||
static_cast<size_t>(batch.TotalBytes()),
|
||||
static_cast<size_t>(bytes_read));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Aborts on error. Updates `mode` to the actual mode used. Returns mapped
|
||||
|
|
|
|||
Loading…
Reference in New Issue