Mirror of https://github.com/google/gemma.cpp.git
Utilize multiple cores to read weight batches.
PiperOrigin-RevId: 811893059
This commit is contained in:
parent
d15731d201
commit
667a3f117a
|
|
@ -465,14 +465,16 @@ static void ReadBatches(const BlobReader& reader,
|
||||||
ThreadingContext& ctx) {
|
ThreadingContext& ctx) {
|
||||||
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
static const auto zone = ctx.profiler.AddZone("Startup.Weights.ReadBatches");
|
||||||
// >5x speedup from parallel reads when cached.
|
// >5x speedup from parallel reads when cached.
|
||||||
ctx.pools.Pool().Run(0, batches.size(), [&](uint64_t i, size_t thread) {
|
ParallelFor(ParallelismStrategy::kHierarchical,
|
||||||
|
batches.size(), ctx, /*cluster_idx=*/0,
|
||||||
|
[&](uint64_t task, size_t thread) {
|
||||||
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
PROFILER_ZONE3(ctx.profiler, thread, zone);
|
||||||
const IOBatch& batch = batches[i];
|
const IOBatch& batch = batches[task];
|
||||||
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
const std::string& key = reader.Keys()[batch.KeyIdx()];
|
||||||
const uint64_t bytes_read = batch.Read(reader.file());
|
const uint64_t bytes_read = batch.Read(reader.file());
|
||||||
if (bytes_read != batch.TotalBytes()) {
|
if (bytes_read != batch.TotalBytes()) {
|
||||||
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
|
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.",
|
||||||
static_cast<size_t>(batch.Offset()),
|
key.c_str(), static_cast<size_t>(batch.Offset()),
|
||||||
static_cast<size_t>(batch.TotalBytes()),
|
static_cast<size_t>(batch.TotalBytes()),
|
||||||
static_cast<size_t>(bytes_read));
|
static_cast<size_t>(bytes_read));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue