mirror of https://github.com/google/gemma.cpp.git

commit c443adee33 (parent 38a08d8095)

3.8x speedup of weights loading via preadv on Linux

Also move BlobReader reading functionality to weights.cc.

PiperOrigin-RevId: 759240310
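Background on the speedup (editorial sketch, not part of the commit): the old loader issued one pread per padded tensor row, so syscall overhead dominated; the new IOBatch coalesces contiguous row ranges and issues a single preadv, which scatters one contiguous file region into many row buffers. The helper names below are hypothetical and only contrast the two approaches; assumes POSIX with preadv (Linux/FreeBSD).

// Before: one pread syscall per row; overhead dominates for small rows.
#include <fcntl.h>
#include <sys/uio.h>
#include <unistd.h>

#include <cstdint>
#include <vector>

static bool ReadRowsOneByOne(int fd, uint64_t offset, size_t row_bytes,
                             const std::vector<uint8_t*>& rows) {
  for (uint8_t* row : rows) {
    if (pread(fd, row, row_bytes, off_t(offset)) != ssize_t(row_bytes))
      return false;
    offset += row_bytes;  // rows are contiguous in the file
  }
  return true;
}

// After: one preadv per batch; the kernel scatters a contiguous file region
// into many (possibly non-contiguous) row buffers. Caller must keep
// rows.size() <= IOV_MAX.
static bool ReadRowsBatched(int fd, uint64_t offset, size_t row_bytes,
                            const std::vector<uint8_t*>& rows) {
  std::vector<iovec> iov;
  iov.reserve(rows.size());
  for (uint8_t* row : rows) iov.push_back({row, row_bytes});
  const ssize_t want = ssize_t(row_bytes * rows.size());
  return preadv(fd, iov.data(), int(iov.size()), off_t(offset)) == want;
}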
@@ -218,12 +218,14 @@ cc_library(
         ":mat",
         ":model_store",
         ":tensor_info",
+        ":threading_context",
         "//compression:compress",
         "//io:blob_store",
         "@highway//:hwy",
         "@highway//:profiler",
         "@highway//:stats",
         "@highway//:thread_pool",
+        "@highway//:timer",
     ],
 )

@@ -141,7 +141,7 @@ HWY_EXPORT(NewSbsWriter);
 SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
 
 SbsReader::SbsReader(const std::string& path)
-    : reader_(gcpp::BlobReader::Make(Path(path))), model_(*reader_) {}
+    : reader_(Path(path)), model_(reader_) {}
 
 }  // namespace gcpp
 #endif  // HWY_ONCE

@@ -77,7 +77,7 @@ class SbsReader {
   const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
 
  private:
-  std::unique_ptr<gcpp::BlobReader> reader_;
+  gcpp::BlobReader reader_;
   gcpp::ModelStore model_;
 };

@@ -53,11 +53,11 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
 
 Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
     : env_(env),
-      reader_(BlobReader::Make(loader.weights, loader.map)),
+      reader_(new BlobReader(loader.weights)),
       model_(*reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config().weight),
      chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
+  weights_.ReadFromBlobs(model_, *reader_, loader.map, env_.ctx.pools.Pool());
   reader_.reset();
 }

--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -31,6 +31,7 @@
 #include "gemma/model_store.h"
 #include "io/blob_store.h"
 #include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"

@@ -95,37 +96,75 @@ void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
   SplitW1NUQ(layer_config);
 }
 
-// Aborts on error.
-static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
-                      const std::vector<BlobRange>& ranges,
-                      MatOwners& mat_owners, const MatPadding padding,
-                      hwy::ThreadPool& pool) {
-  HWY_ASSERT(mats.size() == ranges.size());
+// Parallel I/O into allocated memory, or mapped view of file. The latter is
+// better when the file is huge, but page faults add noise to measurements.
+enum class Mode { kRead, kMap };
 
-  if (reader.IsMapped()) {
-    PROFILER_ZONE("Startup.Weights.Map");
-    for (size_t i = 0; i < mats.size(); ++i) {
-      // SetPtr does not change the stride, but it is expected to be packed
-      // because that is what Compress() writes to the file.
-      const size_t mat_bytes = mats[i]->PackedBytes();
-      // Ensure blob size matches that computed from metadata.
-      HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
-
-      hwy::Span<const uint8_t> span = reader.MappedSpan<uint8_t>(ranges[i]);
-      HWY_ASSERT(span.size() == mat_bytes);
-      mats[i]->SetPtr(const_cast<uint8_t*>(span.data()), mats[i]->Stride());
-    }
-    return;
-  }
+// Decides whether to read or map based on heuristics and user override.
+static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
+  const Allocator& allocator = ThreadingContext::Get().allocator;
+  // User has explicitly requested a map or read via args.
+  if (map == Tristate::kTrue) return Mode::kMap;
+  if (map == Tristate::kFalse) return Mode::kRead;
+  // Else: use heuristics to choose. Note that `FreeMiB` is generally low
+  // because idle memory is used as cache, so do not use it to decide.
+  const size_t file_mib = file_bytes >> 20;
+  const size_t total_mib = allocator.TotalMiB();
+  if (file_mib > total_mib) {
+    HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
+             static_cast<size_t>(file_mib), total_mib);
+  }
+  // Large fraction of total.
+  if (file_mib >= total_mib / 3) return Mode::kMap;
+  // Big enough that even parallel loading wouldn't be quick.
+  if (file_mib > 50 * 1024) return Mode::kMap;
+  return Mode::kRead;
+}
 
-  PROFILER_ZONE("Startup.Weights.AllocateAndEnqueue");
+MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
+  const Allocator& allocator = ThreadingContext::Get().allocator;
+  if (file_bytes % allocator.BasePageBytes() == 0) {
+    MapPtr mapped = file.Map();
+    if (!mapped) {
+      HWY_WARN("Failed to map file (%zu KiB), reading instead.",
+               static_cast<size_t>(file_bytes >> 10));
+    }
+    return mapped;  // [editorial: this line was lost in extraction]
+  } else {
+    HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
+             static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
+  }
+  return MapPtr();
+}
 
-  // NOTE: this changes the stride of `mats`!
-  mat_owners.AllocateFor(mats, padding, pool);
+static void MapAll(const std::vector<MatPtr*>& mats,
+                   const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
+  PROFILER_ZONE("Startup.Weights.Map");
+  for (size_t i = 0; i < mats.size(); ++i) {
+    // SetPtr does not change the stride, but it is expected to be packed
+    // because that is what Compress() writes to the file.
+    const size_t mat_bytes = mats[i]->PackedBytes();
+    // Ensure blob size matches that computed from metadata.
+    HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
+
+    mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
+                    mats[i]->Stride());
+  }
+}
+
+std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
+                                 const std::vector<MatPtr*>& mats,
+                                 const uint64_t file_bytes) {
+  PROFILER_ZONE("Startup.Weights.MakeBatches");
+  // Batches must be contiguous but blobs are padded, hence at least one
+  // batch per tensor, and more when tensor rows exceed the batch size.
+  std::vector<IOBatch> batches;
+  batches.reserve(mats.size());
 
-  // Enqueue the read requests, one per row in each tensor.
   for (size_t i = 0; i < mats.size(); ++i) {
     uint64_t offset = ranges[i].offset;
+    HWY_ASSERT(ranges[i].End() <= file_bytes);
+
+    batches.emplace_back(offset, ranges[i].key_idx);
     const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
     // Caution, `RowT` requires knowledge of the actual type. We instead use
     // the first row, which is the same for any type, and advance the *byte*

@@ -133,21 +172,70 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
     const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
     uint8_t* row = mats[i]->RowT<uint8_t>(0);
     for (size_t r = 0; r < mats[i]->Rows(); ++r) {
-      reader.Enqueue(BlobRange{.offset = offset,
-                               .bytes = file_bytes_per_row,
-                               .key_idx = ranges[i].key_idx},
-                     row);
+      if (!batches.back().Add(row, file_bytes_per_row)) {  // Full batch.
+        batches.emplace_back(offset, ranges[i].key_idx);
+        // Adding to an empty batch is always successful.
+        HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
+      }
       offset += file_bytes_per_row;
       row += mem_stride_bytes;
       // Keep the in-memory row padding uninitialized so msan detects any use.
     }
+    HWY_ASSERT(offset == ranges[i].End());
   }
 
-  reader.ReadAll(pool);
+  HWY_ASSERT(batches.size() >= mats.size());
+  return batches;
+}
+
+// Parallel synchronous I/O. Note that O_DIRECT seems undesirable because we
+// want to use the OS cache between consecutive runs.
+static void ReadBatches(const BlobReader& reader,
+                        const std::vector<IOBatch>& batches,
+                        hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Startup.Weights.Read");
+  // >5x speedup from parallel reads when cached.
+  pool.Run(0, batches.size(), [&](uint64_t i, size_t /*thread*/) {
+    const IOBatch& batch = batches[i];
+    const std::string& key = reader.Keys()[batch.KeyIdx()];
+    const uint64_t bytes_read = batch.Read(reader.file());
+    if (bytes_read != batch.TotalBytes()) {
+      HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
+                static_cast<size_t>(batch.Offset()),
+                static_cast<size_t>(batch.TotalBytes()),
+                static_cast<size_t>(bytes_read));
+    }
+  });
+}
+
+// Aborts on error.
+static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
+                      const std::vector<BlobRange>& ranges, Tristate map,
+                      MatOwners& mat_owners, const MatPadding padding,
+                      hwy::ThreadPool& pool) {
+  HWY_ASSERT(mats.size() == ranges.size());
+
+  if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
+    MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
+    if (mapped) {
+      MapAll(mats, ranges, mapped);
+      return;
+    }
+  }  // otherwise fall through to read mode
+
+  {
+    PROFILER_ZONE("Startup.Weights.Allocate");
+    // NOTE: this changes the stride of `mats`!
+    mat_owners.AllocateFor(mats, padding, pool);
+  }
+
+  const std::vector<IOBatch> batches =
+      MakeBatches(ranges, mats, reader.file_bytes());
+  ReadBatches(reader, batches, pool);
 }
 
 void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
-                                 hwy::ThreadPool& pool) {
+                                 Tristate map, hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector<MatPtr*> mats;
   std::vector<BlobRange> ranges;

@@ -171,7 +259,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     });
   });
 
-  MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);
+  MapOrRead(mats, reader, ranges, map, mat_owners_, padding, pool);
 
   Fixup(pool);
 }
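Editorial note (not part of the diff): to make the ChooseMode heuristic above concrete, consider Tristate::kDefault with a 16 GiB (16384 MiB) weight file. On a 32 GiB machine, total_mib / 3 is 10922, so file_mib >= total_mib / 3 holds and the file is mapped. On a 64 GiB machine, 16384 is below both 21845 (one third of total) and 51200 (the 50 GiB cutoff), so the weights are instead read with batched parallel I/O.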
@@ -609,8 +609,9 @@ class WeightsOwner {
   // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
   WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
 
-  // Reads tensor data from `BlobStore` or aborts on error.
-  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+  // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
+  // override for whether to map blobs or read them.
+  void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
                      hwy::ThreadPool& pool);
 
   // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically

@@ -647,7 +648,7 @@ class WeightsOwner {
     return float_weights_.get();
   }
 
-  // Usually taken care of by `ReadOrAllocate`, but must also be called by
+  // Usually taken care of by `ReadFromBlobs`, but must also be called by
   // `optimize_test`, which updates the attention weights from which this copies.
   void Fixup(hwy::ThreadPool& pool);

@@ -18,7 +18,6 @@
 #include <string.h>  // strcmp
 
 #include <atomic>
-#include <memory>
 #include <string>
 #include <vector>

@@ -108,11 +107,10 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
                hwy::ThreadPool& pool) {
   HWY_ASSERT(reader.Keys().size() == blobs.size());
   HWY_ASSERT(ranges.size() == blobs.size());
-  for (size_t i = 0; i < blobs.size(); ++i) {
+  pool.Run(0, blobs.size(), [&](size_t i, size_t /*thread*/) {
     HWY_ASSERT(ranges[i].bytes == blobs[i].size());
-    reader.Enqueue(ranges[i], blobs[i].data());
-  }
-  reader.ReadAll(pool);
+    reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data());
+  });
 }
 
 // Parallelizes ReadBlobs across (two) packages, if available.

@@ -213,20 +211,14 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
 }
 
 // Compares two sbs files, including blob order.
-void ReadAndCompareBlobs(const char* path1, const char* path2) {
-  const Tristate map = Tristate::kFalse;
-  std::unique_ptr<BlobReader> reader1 = BlobReader::Make(Path(path1), map);
-  std::unique_ptr<BlobReader> reader2 = BlobReader::Make(Path(path2), map);
-  if (!reader1 || !reader2) {
-    HWY_ABORT(
-        "Failed to create readers for files %s %s, see error messages above.\n",
-        path1, path2);
-  }
+void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
+  BlobReader reader1(path1);
+  BlobReader reader2(path2);
 
-  CompareKeys(*reader1, *reader2);
-  const RangeVec ranges1 = AllRanges(reader1->Keys(), *reader1);
-  const RangeVec ranges2 = AllRanges(reader2->Keys(), *reader2);
-  CompareRangeSizes(reader1->Keys(), ranges1, ranges2);
+  CompareKeys(reader1, reader2);
+  const RangeVec ranges1 = AllRanges(reader1.Keys(), reader1);
+  const RangeVec ranges2 = AllRanges(reader2.Keys(), reader2);
+  CompareRangeSizes(reader1.Keys(), ranges1, ranges2);
 
   // Single allocation, avoid initializing the memory.
   const size_t total_bytes = TotalBytes(ranges1) + TotalBytes(ranges2);

@@ -236,10 +228,10 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
   BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);
 
   NestedPools& pools = ThreadingContext::Get().pools;
-  ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
-                blobs2, pools);
+  ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
+                pools);
 
-  CompareBlobs(reader1->Keys(), blobs1, blobs2, total_bytes, pools);
+  CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
 }
 
 }  // namespace gcpp

@@ -251,6 +243,6 @@ int main(int argc, char** argv) {
   if (strcmp(argv[1], argv[2]) == 0) {
     HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]);
   }
-  gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
+  gcpp::ReadAndCompareBlobs(gcpp::Path(argv[1]), gcpp::Path(argv[2]));
   return 0;
 }
--- a/io/blob_store.cc
+++ b/io/blob_store.cc
@@ -30,7 +30,6 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_compiler_arch.h"
-#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -289,10 +288,12 @@ class BlobStore {
   std::vector<hwy::uint128_t> directory_;  // two per blob, see `SetRange`.
 };  // BlobStore
 
-BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
-                       const BlobStore& bs, BlobReader::Mode mode)
-    : file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
-  HWY_ASSERT(file_ && file_bytes_ != 0);
+BlobReader::BlobReader(const Path& blob_path)
+    : file_(OpenFileOrAbort(blob_path, "r")), file_bytes_(file_->FileSize()) {
+  if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
+
+  BlobStore bs(*file_);
+  HWY_ASSERT(bs.IsValid(file_bytes_));  // IsValid already printed a warning
 
   keys_.reserve(bs.NumBlobs());
   for (const hwy::uint128_t key : bs.Keys()) {

@@ -309,119 +310,6 @@ BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
         BlobRange{.offset = offset, .bytes = bytes, .key_idx = key_idx});
     key_idx_for_key_[keys_[key_idx]] = key_idx;
   }
-
-  if (mode_ == Mode::kMap) {
-    const Allocator& allocator = ThreadingContext::Get().allocator;
-    // Verify `kEndAlign` is an upper bound on the page size.
-    if (kEndAlign % allocator.BasePageBytes() != 0) {
-      HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",
-                kEndAlign, allocator.BasePageBytes());
-    }
-    if (file_bytes_ % allocator.BasePageBytes() == 0) {
-      mapped_ = file_->Map();
-      if (!mapped_) {
-        HWY_WARN("Failed to map file (%zu KiB), reading instead.",
-                 static_cast<size_t>(file_bytes_ >> 10));
-        mode_ = Mode::kRead;  // Switch to kRead and continue.
-      }
-    } else {
-      HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
-               static_cast<size_t>(file_bytes_ >> 10),
-               allocator.BasePageBytes());
-      mode_ = Mode::kRead;  // Switch to kRead and continue.
-    }
-  }
-
-  if (mode_ == Mode::kRead) {
-    // Potentially one per tensor row, so preallocate many.
-    requests_.reserve(2 << 20);
-  }
-}
-
-void BlobReader::Enqueue(const BlobRange& range, void* data) {
-  // Debug-only because there may be many I/O requests (per row).
-  if constexpr (HWY_IS_DEBUG_BUILD) {
-    HWY_DASSERT(!IsMapped());
-    HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
-    const BlobRange& blob_range = Range(range.key_idx);
-    HWY_DASSERT(blob_range.End() <= file_bytes_);
-    if (range.End() > blob_range.End()) {
-      HWY_ABORT(
-          "Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
-          range.bytes, keys_[range.key_idx].c_str(),
-          static_cast<size_t>(range.End()),
-          static_cast<size_t>(blob_range.End()));
-    }
-  }
-  requests_.emplace_back(range, data);
-}
-
-// Parallel synchronous I/O. Alternatives considered:
-// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
-//   pread calls preadv with a single iovec.
-//   TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
-// - O_DIRECT seems undesirable because we do want to use the OS cache
-//   between consecutive runs.
-void BlobReader::ReadAll(hwy::ThreadPool& pool) const {
-  PROFILER_ZONE("Startup.ReadAll");
-  HWY_ASSERT(!IsMapped());
-  // >5x speedup from parallel reads when cached.
-  pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
-    const BlobRange& range = requests_[i].range;
-    const uint64_t end = range.End();
-    const std::string& key = keys_[range.key_idx];
-    const BlobRange& blob_range = Range(range.key_idx);
-    HWY_ASSERT(blob_range.End() <= file_bytes_);
-    if (end > blob_range.End()) {
-      HWY_ABORT(
-          "Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
-          range.bytes, key.c_str(), static_cast<size_t>(end),
-          static_cast<size_t>(blob_range.End()));
-    }
-    if (!file_->Read(range.offset, range.bytes, requests_[i].data)) {
-      HWY_ABORT("Read failed for %s from %zu, %zu bytes to %p.", key.c_str(),
-                static_cast<size_t>(range.offset), range.bytes,
-                requests_[i].data);
-    }
-  });
-}
-
-// Decides whether to read or map the file.
-static BlobReader::Mode ChooseMode(uint64_t file_mib, Tristate map) {
-  const Allocator& allocator = ThreadingContext::Get().allocator;
-  // User has explicitly requested a map or read via args.
-  if (map == Tristate::kTrue) return BlobReader::Mode::kMap;
-  if (map == Tristate::kFalse) return BlobReader::Mode::kRead;
-  // Else: use heuristics to choose. Note that `FreeMiB` is generally low
-  // because idle memory is used as cache, so do not use it to decide.
-  const size_t total_mib = allocator.TotalMiB();
-  if (file_mib > total_mib) {
-    HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
-             static_cast<size_t>(file_mib), total_mib);
-  }
-  // Large fraction of total.
-  if (file_mib >= total_mib / 3) return BlobReader::Mode::kMap;
-  // Big enough that even parallel loading wouldn't be quick.
-  if (file_mib > 50 * 1024) return BlobReader::Mode::kMap;
-  return BlobReader::Mode::kRead;
-}
-
-std::unique_ptr<BlobReader> BlobReader::Make(const Path& blob_path,
-                                             const Tristate map) {
-  if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
-  std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
-  if (!file) HWY_ABORT("Failed to open file %s", blob_path.path.c_str());
-  const uint64_t file_bytes = file->FileSize();
-  if (file_bytes == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
-
-  // Even if `kMap`, read the directory via the `kRead` mode for simplicity.
-  BlobStore bs(*file);
-  if (!bs.IsValid(file_bytes)) {
-    return std::unique_ptr<BlobReader>();  // IsValid already printed a warning
-  }
-
-  return std::unique_ptr<BlobReader>(new BlobReader(
-      std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
 }
 
 // Split into chunks for load-balancing even if blob sizes vary.
@@ -55,25 +55,19 @@ struct BlobIO2 {
 class BlobStore;
 
 // Reads `BlobStore` header, converts keys to strings and creates a hash map for
-// faster lookups, and reads or maps blob data.
-// Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
-// because they are const.
-// TODO(janwas): split into header and reader/mapper classes.
+// faster lookups.
+// TODO(janwas): rename to BlobFinder or similar.
+// Thread-safe: it is safe to concurrently call all methods.
 class BlobReader {
  public:
-  // Parallel I/O into allocated memory, or mapped view of file. The latter is
-  // better when the file is huge, but page faults add noise to measurements.
-  enum class Mode { kRead, kMap };
-
   // Acquires ownership of `file` (which must be non-null) and reads its header.
-  // Factory function instead of ctor because this can fail (return null).
-  static std::unique_ptr<BlobReader> Make(const Path& blob_path,
-                                          Tristate map = Tristate::kDefault);
+  // Aborts on error.
+  explicit BlobReader(const Path& blob_path);
 
-  ~BlobReader() = default;
-
-  // Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
-  bool IsMapped() const { return mode_ == Mode::kMap; }
+  // Non-const version required for File::Map().
+  File& file() { return *file_; }
+  const File& file() const { return *file_; }
+  uint64_t file_bytes() const { return file_bytes_; }
 
   const std::vector<std::string>& Keys() const { return keys_; }

@@ -92,20 +86,8 @@ class BlobReader {
     return &range;
   }
 
-  // Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
-  // everything else except `CallWithSpan` is in units of bytes.
-  template <typename T>
-  hwy::Span<const T> MappedSpan(const BlobRange& range) const {
-    HWY_ASSERT(IsMapped());
-    HWY_ASSERT(range.bytes % sizeof(T) == 0);
-    return hwy::Span<const T>(
-        HWY_RCAST_ALIGNED(const T*, mapped_.get() + range.offset),
-        range.bytes / sizeof(T));
-  }
-
   // Returns error, or calls `func(span)` with the blob identified by `key`.
-  // This may allocate memory for the blob, and is intended for small blobs for
-  // which an aligned allocation is unnecessary.
+  // Allocates unaligned memory for the blob; intended for small metadata blobs.
   template <typename T, class Func>
   bool CallWithSpan(const std::string& key, const Func& func) const {
     const BlobRange* range = Find(key);

@@ -114,11 +96,6 @@ class BlobReader {
       return false;
     }
 
-    if (mode_ == Mode::kMap) {
-      func(MappedSpan<T>(*range));
-      return true;
-    }
-
     HWY_ASSERT(range->bytes % sizeof(T) == 0);
     std::vector<T> storage(range->bytes / sizeof(T));
     if (!file_->Read(range->offset, range->bytes, storage.data())) {

@@ -131,30 +108,13 @@ class BlobReader {
     return true;
   }
 
-  // The following methods must only be called if `!IsMapped()`.
-
-  // Enqueues a BlobIO2 for `ReadAll` to execute.
-  void Enqueue(const BlobRange& range, void* data);
-
-  // Reads in parallel all enqueued requests to the specified destinations.
-  // Aborts on error.
-  void ReadAll(hwy::ThreadPool& pool) const;
-
  private:
-  // Only for use by `Make`.
-  BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
-             const BlobStore& bs, Mode mode);
-
   const std::unique_ptr<File> file_;
   const uint64_t file_bytes_;
-  Mode mode_;
 
   std::vector<std::string> keys_;
   std::vector<BlobRange> ranges_;
   std::unordered_map<std::string, size_t> key_idx_for_key_;
-
-  MapPtr mapped_;  // only if `kMap`
-  std::vector<BlobIO2> requests_;  // only if `kRead`
 };
 
 // Collects references to blobs and writes them all at once with parallel I/O.
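Editorial sketch (not part of the commit): typical use of the simplified BlobReader API above. The key name "tokenizer" is a placeholder.

#include <cstdio>

#include "io/blob_store.h"
#include "io/io.h"

void PrintBlobSize(const gcpp::Path& weights) {
  gcpp::BlobReader reader(weights);  // aborts if the file is missing/invalid
  if (reader.Find("tokenizer") == nullptr) return;  // key not present
  // CallWithSpan allocates a temporary buffer; intended for small blobs.
  reader.CallWithSpan<char>("tokenizer", [](hwy::Span<const char> span) {
    printf("tokenizer blob: %zu bytes\n", span.size());
  });
}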
@@ -19,7 +19,6 @@
 
 #include <algorithm>
 #include <array>
-#include <memory>
 #include <string>
 #include <vector>

@@ -36,7 +35,7 @@ namespace {
 class BlobStoreTest : public testing::Test {};
 #endif
 
-void TestWithMapped(Tristate map) {
+TEST(BlobStoreTest, TestReadWrite) {
   hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
 
   static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};

@@ -59,43 +58,30 @@ void TestWithMapped(Tristate map) {
 
   std::fill(buffer.begin(), buffer.end(), 0);
 
-  std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
-  HWY_ASSERT(reader);
+  const BlobReader reader(path);
 
-  HWY_ASSERT_EQ(reader->Keys().size(), 2);
-  HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
-  HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());
+  HWY_ASSERT_EQ(reader.Keys().size(), 2);
+  HWY_ASSERT_STRING_EQ(reader.Keys()[0].c_str(), keyA.c_str());
+  HWY_ASSERT_STRING_EQ(reader.Keys()[1].c_str(), keyB.c_str());
 
-  const BlobRange* range = reader->Find(keyA);
+  const BlobRange* range = reader.Find(keyA);
   HWY_ASSERT(range);
   const uint64_t offsetA = range->offset;
   HWY_ASSERT_EQ(offsetA, 256);  // kBlobAlign
   HWY_ASSERT_EQ(range->bytes, 5);
-  range = reader->Find(keyB);
+  range = reader.Find(keyB);
   HWY_ASSERT(range);
   const uint64_t offsetB = range->offset;
   HWY_ASSERT_EQ(offsetB, 2 * 256);
   HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
 
-  if (!reader->IsMapped()) {
-    char str[5];
-    reader->Enqueue(
-        BlobRange{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
-    reader->Enqueue(
-        BlobRange{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
-        buffer.data());
-    reader->ReadAll(pool);
-    HWY_ASSERT_STRING_EQ("DATA", str);
-    HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
-  }
-
   HWY_ASSERT(
-      reader->CallWithSpan<char>(keyA, [](const hwy::Span<const char> span) {
+      reader.CallWithSpan<char>(keyA, [](const hwy::Span<const char> span) {
        HWY_ASSERT_EQ(span.size(), 5);
        HWY_ASSERT_STRING_EQ("DATA", span.data());
      }));
   HWY_ASSERT(
-      reader->CallWithSpan<float>(keyB, [](const hwy::Span<const float> span) {
+      reader.CallWithSpan<float>(keyB, [](const hwy::Span<const float> span) {
        HWY_ASSERT_EQ(span.size(), 4);
        HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), span.data(), span.size());
      }));

@@ -104,11 +90,6 @@ void TestWithMapped(Tristate map) {
   unlink(path_str);
 }
 
-TEST(BlobStoreTest, TestReadWrite) {
-  TestWithMapped(Tristate::kFalse);
-  TestWithMapped(Tristate::kTrue);
-}
-
 // Ensures padding works for any number of random-sized blobs.
 TEST(BlobStoreTest, TestNumBlobs) {
   hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();

@@ -143,17 +124,14 @@ TEST(BlobStoreTest, TestNumBlobs) {
   HWY_ASSERT(blobs.size() == num_blobs);
   writer.WriteAll(pool, path);
 
-  const Tristate map = Tristate::kFalse;
-  std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
-  HWY_ASSERT(reader);
-  HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
+  BlobReader reader(path);
+  HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
   pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
-    HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
-                         std::to_string(i).c_str());
-    const BlobRange* range = reader->Find(keys[i]);
+    HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str());
+    const BlobRange* range = reader.Find(keys[i]);
     HWY_ASSERT(range);
     HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
-    HWY_ASSERT(reader->CallWithSpan<uint8_t>(
+    HWY_ASSERT(reader.CallWithSpan<uint8_t>(
         keys[i], [path_str, num_blobs, i, range,
                   &blobs](const hwy::Span<const uint8_t> span) {
           HWY_ASSERT_EQ(blobs[i].size(), span.size());
--- a/io/io.cc
+++ b/io/io.cc
@@ -19,23 +19,45 @@
 // check this in source code because we support multiple build systems.
 #if !HWY_OS_WIN
 
-// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
+// Request POSIX 2008, including `pread()` and `posix_fadvise()`. This also
+// implies `_POSIX_C_SOURCE`.
 #if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
 #undef _XOPEN_SOURCE
-#define _XOPEN_SOURCE 700
-#endif
-#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
-#define _POSIX_C_SOURCE 200809
+#define _XOPEN_SOURCE 700  // SUSv4
 #endif
 
 // Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
 #undef _FILE_OFFSET_BITS
 #define _FILE_OFFSET_BITS 64
 
-#include <fcntl.h>  // open
+#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
+    (!defined(__ANDROID_API__) || __ANDROID_API__ >= 24)
+#define GEMMA_IO_PREADV 1
+#else
+#define GEMMA_IO_PREADV 0
+#endif
+
+#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
+    (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
+#define GEMMA_IO_FADVISE 1
+#else
+#define GEMMA_IO_FADVISE 0
+#endif
+
+#if GEMMA_IO_PREADV
+// Replacement for the _BSD_SOURCE specified by preadv documentation.
+#ifndef _DEFAULT_SOURCE
+#define _DEFAULT_SOURCE
+#endif
+#include <errno.h>
+#include <sys/uio.h>  // preadv
+#endif  // GEMMA_IO_PREADV
+
+#include <fcntl.h>   // open
+#include <limits.h>  // IOV_MAX
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>  // SEEK_END - unistd isn't enough for IDE.
 #include <sys/types.h>
 // Old OSX may require sys/types.h before sys/mman.h.
 #include <sys/mman.h>  // mmap

@@ -43,6 +65,7 @@
 #include <unistd.h>  // read, write, close
 
 #include <memory>
+#include <string>
 
 #include "io/io.h"
 #include "util/allocator.h"

@@ -119,6 +142,8 @@ class FilePosix : public File {
       HWY_ASSERT(munmap(ptr, mapping_size) == 0);
     }));
   }
+
+  int Handle() const override { return fd_; }
 };  // FilePosix
 
 HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(

@@ -133,15 +158,83 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
   const int fd = open(filename.path.c_str(), flags, 0644);
   if (fd < 0) return file;
 
-#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
+#if GEMMA_IO_FADVISE
   if (is_read) {
     // Doubles the readahead window, which seems slightly faster when cached.
     (void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
   }
-#endif
+#endif  // GEMMA_IO_FADVISE
 
   return std::make_unique<FilePosix>(fd);
 }
+
+std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
+  std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
+  if (!file) {
+    HWY_ABORT("Failed to open %s", filename.path.c_str());
+  }
+  return file;
+}
+
+std::string ReadFileToString(const Path& path) {
+  std::unique_ptr<File> file = OpenFileOrAbort(path, "r");
+  const size_t size = file->FileSize();
+  if (size == 0) {
+    HWY_ABORT("Empty file %s", path.path.c_str());
+  }
+  std::string content(size, ' ');
+  if (!file->Read(0, size, content.data())) {
+    HWY_ABORT("Failed to read %s", path.path.c_str());
+  }
+  return content;
+}
+
+#ifdef IOV_MAX
+constexpr size_t kMaxSpans = IOV_MAX;
+#else
+constexpr size_t kMaxSpans = 1024;  // Linux limit
+#endif
+
+IOBatch::IOBatch(uint64_t offset, size_t key_idx)
+    : offset_(offset), key_idx_(key_idx) {
+  spans_.reserve(kMaxSpans);
+}
+
+// Returns false if the batch was full; if so, call again on a new batch.
+bool IOBatch::Add(void* mem, size_t bytes) {
+  if (spans_.size() >= kMaxSpans) return false;
+  if (total_bytes_ + bytes > 0x7FFFF000) return false;  // Linux limit
+  spans_.push_back({.mem = mem, .bytes = bytes});
+  total_bytes_ += bytes;
+  return true;
+}
+
+uint64_t IOBatch::Read(const File& file) const {
+#if GEMMA_IO_PREADV
+  HWY_ASSERT(!spans_.empty());
+
+  ssize_t bytes_read;
+  for (;;) {
+    bytes_read =
+        preadv(file.Handle(), reinterpret_cast<const iovec*>(spans_.data()),
+               static_cast<int>(spans_.size()), offset_);
+    if (bytes_read >= 0) break;
+    if (errno == EINTR) continue;  // signal: retry
+    HWY_WARN("preadv failed, errno %d.", errno);
+    return 0;
+  }
+  return static_cast<uint64_t>(bytes_read);
+#else
+  uint64_t total = 0;
+  uint64_t offset = offset_;
+  for (const IOSpan& span : spans_) {
+    if (!file.Read(offset, span.bytes, span.mem)) return 0;
+    total += span.bytes;
+    offset += span.bytes;
+  }
+  return total;
+#endif
+}
 
 }  // namespace gcpp
 #endif  // !HWY_OS_WIN
--- a/io/io.h
+++ b/io/io.h
@@ -22,6 +22,7 @@
 #include <memory>
 #include <string>
 #include <utility>  // std::move
+#include <vector>
 
 #include "util/allocator.h"
 #include "hwy/base.h"

@@ -49,22 +50,66 @@ class File {
   virtual uint64_t FileSize() const = 0;
 
   // Returns true if all the requested bytes were read.
+  // Thread-compatible.
   virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
 
   // Returns true if all the requested bytes were written.
+  // Thread-compatible.
   virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
 
   // Maps the entire file into read-only memory or returns nullptr on failure.
   // We do not support offsets because Windows requires them to be a multiple of
   // the allocation granularity, which is 64 KiB. Some implementations may fail
-  // if the file is zero-sized and return a nullptr.
+  // if the file is zero-sized and return a nullptr. Non-const because it may
+  // modify internal state. This is only expected to be called once per file.
   virtual MapPtr Map() = 0;
+
+  // For use by `IOBatch::Read`.
+  virtual int Handle() const { return -1; }
 };
 
 // Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
 // named 'OpenFile' to avoid a conflict with Windows.h #define.
 std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
+
+// As above, but aborts instead of returning nullptr.
+std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode);
+
+// Compatible with Linux iovec.
+struct IOSpan {
+  void* mem;
+  size_t bytes;
+};
+
+// Wrapper for Linux/BSD `preadv`, calling `File::Read` on other systems. To
+// insert row padding, we previously issued one IO per tensor row, which is
+// expensive. `preadv` reduces up to 1024 syscalls to 1.
+// The file data must be contiguous starting from `IOBatch::offset_`, because
+// `preadv` does not support per-`IOSpan` offsets.
+class IOBatch {
+ public:
+  // Reserves memory in `spans_`. `key_idx` identifies the blob/tensor.
+  explicit IOBatch(uint64_t offset, size_t key_idx);
+
+  // The next `bytes` will be read from file into `mem`.
+  // Returns false if the batch was full; if so, call again on a new batch.
+  bool Add(void* mem, size_t bytes);
+
+  uint64_t Offset() const { return offset_; }
+  uint64_t TotalBytes() const { return total_bytes_; }
+  size_t KeyIdx() const { return key_idx_; }
+
+  // Returns the total number of bytes read, or 0 if any I/O failed.
+  // Thread-compatible.
+  uint64_t Read(const File& file) const;
+
+ private:
+  uint64_t offset_;
+  uint64_t total_bytes_ = 0;
+  size_t key_idx_;
+  std::vector<IOSpan> spans_;  // contiguous in the file.
+};
 
 // Wrapper for strings representing a path name. Differentiates vs. arbitrary
 // strings and supports shortening for display purposes.
 struct Path {

@@ -97,21 +142,7 @@ struct Path {
 };
 
 // Aborts on error.
-static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
-  std::unique_ptr<File> file = OpenFileOrNull(path, "r");
-  if (!file) {
-    HWY_ABORT("Failed to open %s", path.path.c_str());
-  }
-  const size_t size = file->FileSize();
-  if (size == 0) {
-    HWY_ABORT("Empty file %s", path.path.c_str());
-  }
-  std::string content(size, ' ');
-  if (!file->Read(0, size, content.data())) {
-    HWY_ABORT("Failed to read %s", path.path.c_str());
-  }
-  return content;
-}
+std::string ReadFileToString(const Path& path);
 
 }  // namespace gcpp
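Editorial sketch (not part of the commit): how a caller might drive the IOBatch API declared above. It reads rows.size() consecutive row_bytes ranges starting at offset; the layout and helper name are assumed for illustration, and this mirrors the MakeBatches/ReadBatches logic in gemma/weights.cc.

#include <cstdint>
#include <vector>

#include "io/io.h"

static bool ReadRows(const gcpp::File& file, uint64_t offset, size_t row_bytes,
                     const std::vector<uint8_t*>& rows) {
  std::vector<gcpp::IOBatch> batches;
  batches.emplace_back(offset, /*key_idx=*/0);
  for (uint8_t* row : rows) {
    // Add fails when the batch hits IOV_MAX spans or ~2 GiB; start a new one
    // at the current file offset.
    if (!batches.back().Add(row, row_bytes)) {
      batches.emplace_back(offset, /*key_idx=*/0);
      if (!batches.back().Add(row, row_bytes)) return false;
    }
    offset += row_bytes;  // file ranges are contiguous within a batch
  }
  for (const gcpp::IOBatch& batch : batches) {
    // One preadv on Linux/FreeBSD, a File::Read loop elsewhere.
    if (batch.Read(file) != batch.TotalBytes()) return false;
  }
  return true;
}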