3.8x speedup of weights loading via preadv on Linux

Also move BlobReader reading functionality to weights.cc

PiperOrigin-RevId: 759240310
Jan Wassenberg 2025-05-15 11:54:38 -07:00 committed by Copybara-Service
parent 38a08d8095
commit c443adee33
12 changed files with 322 additions and 289 deletions
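The core of this change, illustrated below before the per-file diffs: instead of issuing one pread() per tensor row, contiguous rows are coalesced into iovec spans and read with a single preadv() call. This is an illustrative sketch only, not code from the commit; the helper name ReadRowsBatched and its signature are invented, and a real implementation (the IOBatch class added below) must also respect IOV_MAX and Linux's 0x7FFFF000-byte cap per call.

// Illustrative sketch only -- not code from this commit.
#include <sys/uio.h>  // preadv, struct iovec (Linux/BSD)

#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reads `num_rows` rows of `row_bytes` each, contiguous in the file starting
// at `offset`, into destinations `row_stride` bytes apart in `dest`.
// Returns the number of bytes the kernel read, or 0 on failure.
static uint64_t ReadRowsBatched(int fd, uint64_t offset, uint8_t* dest,
                                size_t num_rows, size_t row_bytes,
                                size_t row_stride) {
  std::vector<iovec> iov;
  iov.reserve(num_rows);
  for (size_t r = 0; r < num_rows; ++r) {
    iov.push_back({dest + r * row_stride, row_bytes});
  }
  for (;;) {  // one syscall for all rows
    const ssize_t ret = preadv(fd, iov.data(), static_cast<int>(iov.size()),
                               static_cast<off_t>(offset));
    if (ret >= 0) return static_cast<uint64_t>(ret);
    if (errno == EINTR) continue;  // interrupted by a signal: retry
    return 0;                      // other error
  }
}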


@ -218,12 +218,14 @@ cc_library(
":mat",
":model_store",
":tensor_info",
":threading_context",
"//compression:compress",
"//io:blob_store",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//:timer",
],
)


@ -141,7 +141,7 @@ HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsReader::SbsReader(const std::string& path)
: reader_(gcpp::BlobReader::Make(Path(path))), model_(*reader_) {}
: reader_(Path(path)), model_(reader_) {}
} // namespace gcpp
#endif // HWY_ONCE


@ -77,7 +77,7 @@ class SbsReader {
const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
private:
std::unique_ptr<gcpp::BlobReader> reader_;
gcpp::BlobReader reader_;
gcpp::ModelStore model_;
};


@ -53,11 +53,11 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
: env_(env),
reader_(BlobReader::Make(loader.weights, loader.map)),
reader_(new BlobReader(loader.weights)),
model_(*reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config().weight),
chat_template_(model_.Tokenizer(), model_.Config().model) {
weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
weights_.ReadFromBlobs(model_, *reader_, loader.map, env_.ctx.pools.Pool());
reader_.reset();
}


@ -31,6 +31,7 @@
#include "gemma/model_store.h"
#include "io/blob_store.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
@ -95,14 +96,48 @@ void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
SplitW1NUQ(layer_config);
}
// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges,
MatOwners& mat_owners, const MatPadding padding,
hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());
// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
enum class Mode { kRead, kMap };
if (reader.IsMapped()) {
// Decides whether to read or map based on heuristics and user override.
static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
const Allocator& allocator = ThreadingContext::Get().allocator;
// User has explicitly requested a map or read via args.
if (map == Tristate::kTrue) return Mode::kMap;
if (map == Tristate::kFalse) return Mode::kRead;
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
// because idle memory is used as cache, so do not use it to decide.
const size_t file_mib = file_bytes >> 20;
const size_t total_mib = allocator.TotalMiB();
if (file_mib > total_mib) {
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
static_cast<size_t>(file_mib), total_mib);
}
// Large fraction of total.
if (file_mib >= total_mib / 3) return Mode::kMap;
// Big enough that even parallel loading wouldn't be quick.
if (file_mib > 50 * 1024) return Mode::kMap;
return Mode::kRead;
}
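// Worked example (illustrative numbers, not from this commit): a 30 GiB
// weight file on a machine with 64 GiB of memory gives file_mib = 30720 and
// total_mib = 65536. 30720 >= 65536 / 3 = 21845, so kMap is chosen. A 10 GiB
// file (10240 MiB) on the same machine is below both thresholds (21845 and
// 51200 MiB), so kRead is chosen.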
MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
const Allocator& allocator = ThreadingContext::Get().allocator;
if (file_bytes % allocator.BasePageBytes() == 0) {
MapPtr mapped = file.Map();
if (mapped) return mapped;
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
static_cast<size_t>(file_bytes >> 10));
} else {
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
static_cast<size_t>(file_bytes >> 10), allocator.BasePageBytes());
}
return MapPtr();
}
static void MapAll(const std::vector<MatPtr*>& mats,
const std::vector<BlobRange>& ranges, const MapPtr& mapped) {
PROFILER_ZONE("Startup.Weights.Map");
for (size_t i = 0; i < mats.size(); ++i) {
// SetPtr does not change the stride, but it is expected to be packed
@ -111,21 +146,25 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
// Ensure blob size matches that computed from metadata.
HWY_ASSERT_M(mat_bytes == ranges[i].bytes, mats[i]->Name());
hwy::Span<const uint8_t> span = reader.MappedSpan<uint8_t>(ranges[i]);
HWY_ASSERT(span.size() == mat_bytes);
mats[i]->SetPtr(const_cast<uint8_t*>(span.data()), mats[i]->Stride());
}
return;
mats[i]->SetPtr(const_cast<uint8_t*>(mapped.get() + ranges[i].offset),
mats[i]->Stride());
}
}
PROFILER_ZONE("Startup.Weights.AllocateAndEnqueue");
std::vector<IOBatch> MakeBatches(const std::vector<BlobRange>& ranges,
const std::vector<MatPtr*>& mats,
const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous in the file, but blobs are padded, hence at
// least one batch per tensor, and more when a tensor's rows exceed the
// per-batch limits (IOV_MAX spans or the per-call byte cap).
std::vector<IOBatch> batches;
batches.reserve(mats.size());
// NOTE: this changes the stride of `mats`!
mat_owners.AllocateFor(mats, padding, pool);
// Enqueue the read requests, one per row in each tensor.
for (size_t i = 0; i < mats.size(); ++i) {
uint64_t offset = ranges[i].offset;
HWY_ASSERT(ranges[i].End() <= file_bytes);
batches.emplace_back(offset, ranges[i].key_idx);
const size_t file_bytes_per_row = mats[i]->Cols() * mats[i]->ElementBytes();
// Caution, `RowT` requires knowledge of the actual type. We instead use
// the first row, which is the same for any type, and advance the *byte*
@ -133,21 +172,70 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
uint8_t* row = mats[i]->RowT<uint8_t>(0);
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
reader.Enqueue(BlobRange{.offset = offset,
.bytes = file_bytes_per_row,
.key_idx = ranges[i].key_idx},
row);
if (!batches.back().Add(row, file_bytes_per_row)) { // Full batch.
batches.emplace_back(offset, ranges[i].key_idx);
// Adding to an empty batch is always successful.
HWY_ASSERT(batches.back().Add(row, file_bytes_per_row));
}
offset += file_bytes_per_row;
row += mem_stride_bytes;
// Keep the in-memory row padding uninitialized so msan detects any use.
}
HWY_ASSERT(offset == ranges[i].End());
}
reader.ReadAll(pool);
HWY_ASSERT(batches.size() >= mats.size());
return batches;
}
// Parallel synchronous I/O. Note that O_DIRECT seems undesirable because we
// want to use the OS cache between consecutive runs.
static void ReadBatches(const BlobReader& reader,
const std::vector<IOBatch>& batches,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Weights.Read");
// >5x speedup from parallel reads when cached.
pool.Run(0, batches.size(), [&](uint64_t i, size_t /*thread*/) {
const IOBatch& batch = batches[i];
const std::string& key = reader.Keys()[batch.KeyIdx()];
const uint64_t bytes_read = batch.Read(reader.file());
if (bytes_read != batch.TotalBytes()) {
HWY_ABORT("Read failed for %s from %zu, %zu bytes; got %zu.", key.c_str(),
static_cast<size_t>(batch.Offset()),
static_cast<size_t>(batch.TotalBytes()),
static_cast<size_t>(bytes_read));
}
});
}
// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges, Tristate map,
MatOwners& mat_owners, const MatPadding padding,
hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());
if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
MapPtr mapped = MapFileOrNull(reader.file(), reader.file_bytes());
if (mapped) {
MapAll(mats, ranges, mapped);
return;
}
} // otherwise fall through to read mode
{
PROFILER_ZONE("Startup.Weights.Allocate");
// NOTE: this changes the stride of `mats`!
mat_owners.AllocateFor(mats, padding, pool);
}
const std::vector<IOBatch> batches =
MakeBatches(ranges, mats, reader.file_bytes());
ReadBatches(reader, batches, pool);
}
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) {
Tristate map, hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<MatPtr*> mats;
std::vector<BlobRange> ranges;
@ -171,7 +259,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
});
});
MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);
MapOrRead(mats, reader, ranges, map, mat_owners_, padding, pool);
Fixup(pool);
}


@ -609,8 +609,9 @@ class WeightsOwner {
// `weight_type` is obtained from `ModelConfig` in `ModelStore`.
WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
// Reads tensor data from `BlobStore` or aborts on error.
void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
// override for whether to map blobs or read them.
void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
@ -647,7 +648,7 @@ class WeightsOwner {
return float_weights_.get();
}
// Usually taken care of by `ReadOrAllocate`, but must also be called by
// Usually taken care of by `ReadFromBlobs`, but must also be called by
// `optimize_test`, which updates the attention weights from which this copies.
void Fixup(hwy::ThreadPool& pool);


@ -18,7 +18,6 @@
#include <string.h> // strcmp
#include <atomic>
#include <memory>
#include <string>
#include <vector>
@ -108,11 +107,10 @@ void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
hwy::ThreadPool& pool) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
for (size_t i = 0; i < blobs.size(); ++i) {
pool.Run(0, blobs.size(), [&](size_t i, size_t /*thread*/) {
HWY_ASSERT(ranges[i].bytes == blobs[i].size());
reader.Enqueue(ranges[i], blobs[i].data());
}
reader.ReadAll(pool);
reader.file().Read(ranges[i].offset, ranges[i].bytes, blobs[i].data());
});
}
// Parallelizes ReadBlobs across (two) packages, if available.
@ -213,20 +211,14 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
}
// Compares two sbs files, including blob order.
void ReadAndCompareBlobs(const char* path1, const char* path2) {
const Tristate map = Tristate::kFalse;
std::unique_ptr<BlobReader> reader1 = BlobReader::Make(Path(path1), map);
std::unique_ptr<BlobReader> reader2 = BlobReader::Make(Path(path2), map);
if (!reader1 || !reader2) {
HWY_ABORT(
"Failed to create readers for files %s %s, see error messages above.\n",
path1, path2);
}
void ReadAndCompareBlobs(const Path& path1, const Path& path2) {
BlobReader reader1(path1);
BlobReader reader2(path2);
CompareKeys(*reader1, *reader2);
const RangeVec ranges1 = AllRanges(reader1->Keys(), *reader1);
const RangeVec ranges2 = AllRanges(reader2->Keys(), *reader2);
CompareRangeSizes(reader1->Keys(), ranges1, ranges2);
CompareKeys(reader1, reader2);
const RangeVec ranges1 = AllRanges(reader1.Keys(), reader1);
const RangeVec ranges2 = AllRanges(reader2.Keys(), reader2);
CompareRangeSizes(reader1.Keys(), ranges1, ranges2);
// Single allocation, avoid initializing the memory.
const size_t total_bytes = TotalBytes(ranges1) + TotalBytes(ranges2);
@ -236,10 +228,10 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);
NestedPools& pools = ThreadingContext::Get().pools;
ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
blobs2, pools);
ReadBothBlobs(reader1, reader2, ranges1, ranges2, total_bytes, blobs1, blobs2,
pools);
CompareBlobs(reader1->Keys(), blobs1, blobs2, total_bytes, pools);
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
}
} // namespace gcpp
@ -251,6 +243,6 @@ int main(int argc, char** argv) {
if (strcmp(argv[1], argv[2]) == 0) {
HWY_ABORT("Filenames are the same, skipping comparison: %s\n", argv[1]);
}
gcpp::ReadAndCompareBlobs(argv[1], argv[2]);
gcpp::ReadAndCompareBlobs(gcpp::Path(argv[1]), gcpp::Path(argv[2]));
return 0;
}


@ -30,7 +30,6 @@
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_compiler_arch.h"
#include "hwy/profiler.h"
namespace gcpp {
@ -289,10 +288,12 @@ class BlobStore {
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
}; // BlobStore
BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
const BlobStore& bs, BlobReader::Mode mode)
: file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
HWY_ASSERT(file_ && file_bytes_ != 0);
BlobReader::BlobReader(const Path& blob_path)
: file_(OpenFileOrAbort(blob_path, "r")), file_bytes_(file_->FileSize()) {
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
BlobStore bs(*file_);
HWY_ASSERT(bs.IsValid(file_bytes_)); // IsValid already printed a warning
keys_.reserve(bs.NumBlobs());
for (const hwy::uint128_t key : bs.Keys()) {
@ -309,119 +310,6 @@ BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
BlobRange{.offset = offset, .bytes = bytes, .key_idx = key_idx});
key_idx_for_key_[keys_[key_idx]] = key_idx;
}
if (mode_ == Mode::kMap) {
const Allocator& allocator = ThreadingContext::Get().allocator;
// Verify `kEndAlign` is an upper bound on the page size.
if (kEndAlign % allocator.BasePageBytes() != 0) {
HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",
kEndAlign, allocator.BasePageBytes());
}
if (file_bytes_ % allocator.BasePageBytes() == 0) {
mapped_ = file_->Map();
if (!mapped_) {
HWY_WARN("Failed to map file (%zu KiB), reading instead.",
static_cast<size_t>(file_bytes_ >> 10));
mode_ = Mode::kRead; // Switch to kRead and continue.
}
} else {
HWY_WARN("Unable to map non-padded file (%zu, %zu), reading instead.",
static_cast<size_t>(file_bytes_ >> 10),
allocator.BasePageBytes());
mode_ = Mode::kRead; // Switch to kRead and continue.
}
}
if (mode_ == Mode::kRead) {
// Potentially one per tensor row, so preallocate many.
requests_.reserve(2 << 20);
}
}
void BlobReader::Enqueue(const BlobRange& range, void* data) {
// Debug-only because there may be many I/O requests (per row).
if constexpr (HWY_IS_DEBUG_BUILD) {
HWY_DASSERT(!IsMapped());
HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
const BlobRange& blob_range = Range(range.key_idx);
HWY_DASSERT(blob_range.End() <= file_bytes_);
if (range.End() > blob_range.End()) {
HWY_ABORT(
"Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
range.bytes, keys_[range.key_idx].c_str(),
static_cast<size_t>(range.End()),
static_cast<size_t>(blob_range.End()));
}
}
requests_.emplace_back(range, data);
}
// Parallel synchronous I/O. Alternatives considered:
// - readv is limited to 0x7FFFF000 bytes on Linux (even 64-bit). Note that
// pread calls preadv with a single iovec.
// TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
// - O_DIRECT seems undesirable because we do want to use the OS cache
// between consecutive runs.
void BlobReader::ReadAll(hwy::ThreadPool& pool) const {
PROFILER_ZONE("Startup.ReadAll");
HWY_ASSERT(!IsMapped());
// >5x speedup from parallel reads when cached.
pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
const BlobRange& range = requests_[i].range;
const uint64_t end = range.End();
const std::string& key = keys_[range.key_idx];
const BlobRange& blob_range = Range(range.key_idx);
HWY_ASSERT(blob_range.End() <= file_bytes_);
if (end > blob_range.End()) {
HWY_ABORT(
"Bug: want to read %zu bytes of %s until %zu, past blob end %zu.",
range.bytes, key.c_str(), static_cast<size_t>(end),
static_cast<size_t>(blob_range.End()));
}
if (!file_->Read(range.offset, range.bytes, requests_[i].data)) {
HWY_ABORT("Read failed for %s from %zu, %zu bytes to %p.", key.c_str(),
static_cast<size_t>(range.offset), range.bytes,
requests_[i].data);
}
});
}
// Decides whether to read or map the file.
static BlobReader::Mode ChooseMode(uint64_t file_mib, Tristate map) {
const Allocator& allocator = ThreadingContext::Get().allocator;
// User has explicitly requested a map or read via args.
if (map == Tristate::kTrue) return BlobReader::Mode::kMap;
if (map == Tristate::kFalse) return BlobReader::Mode::kRead;
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
// because idle memory is used as cache, so do not use it to decide.
const size_t total_mib = allocator.TotalMiB();
if (file_mib > total_mib) {
HWY_WARN("Weight file %zu MiB > detected memory %zu MiB.",
static_cast<size_t>(file_mib), total_mib);
}
// Large fraction of total.
if (file_mib >= total_mib / 3) return BlobReader::Mode::kMap;
// Big enough that even parallel loading wouldn't be quick.
if (file_mib > 50 * 1024) return BlobReader::Mode::kMap;
return BlobReader::Mode::kRead;
}
std::unique_ptr<BlobReader> BlobReader::Make(const Path& blob_path,
const Tristate map) {
if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
if (!file) HWY_ABORT("Failed to open file %s", blob_path.path.c_str());
const uint64_t file_bytes = file->FileSize();
if (file_bytes == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
// Even if `kMap`, read the directory via the `kRead` mode for simplicity.
BlobStore bs(*file);
if (!bs.IsValid(file_bytes)) {
return std::unique_ptr<BlobReader>(); // IsValid already printed a warning
}
return std::unique_ptr<BlobReader>(new BlobReader(
std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
}
// Split into chunks for load-balancing even if blob sizes vary.


@ -55,25 +55,19 @@ struct BlobIO2 {
class BlobStore;
// Reads `BlobStore` header, converts keys to strings and creates a hash map for
// faster lookups, and reads or maps blob data.
// Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
// because they are const.
// TODO(janwas): split into header and reader/mapper classes.
// faster lookups.
// TODO(janwas): rename to BlobFinder or similar.
// Thread-safe: it is safe to concurrently call all methods.
class BlobReader {
public:
// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
enum class Mode { kRead, kMap };
// Acquires ownership of `file` (which must be non-null) and reads its header.
// Factory function instead of ctor because this can fail (return null).
static std::unique_ptr<BlobReader> Make(const Path& blob_path,
Tristate map = Tristate::kDefault);
// Aborts on error.
explicit BlobReader(const Path& blob_path);
~BlobReader() = default;
// Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
bool IsMapped() const { return mode_ == Mode::kMap; }
// Non-const version required for File::Map().
File& file() { return *file_; }
const File& file() const { return *file_; }
uint64_t file_bytes() const { return file_bytes_; }
const std::vector<std::string>& Keys() const { return keys_; }
@ -92,20 +86,8 @@ class BlobReader {
return &range;
}
// Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
// everything else except `CallWithSpan` is in units of bytes.
template <typename T>
hwy::Span<const T> MappedSpan(const BlobRange& range) const {
HWY_ASSERT(IsMapped());
HWY_ASSERT(range.bytes % sizeof(T) == 0);
return hwy::Span<const T>(
HWY_RCAST_ALIGNED(const T*, mapped_.get() + range.offset),
range.bytes / sizeof(T));
}
// Returns error, or calls `func(span)` with the blob identified by `key`.
// This may allocate memory for the blob, and is intended for small blobs for
// which an aligned allocation is unnecessary.
// Allocates unaligned memory for the blob; intended for small metadata blobs.
template <typename T, class Func>
bool CallWithSpan(const std::string& key, const Func& func) const {
const BlobRange* range = Find(key);
@ -114,11 +96,6 @@ class BlobReader {
return false;
}
if (mode_ == Mode::kMap) {
func(MappedSpan<T>(*range));
return true;
}
HWY_ASSERT(range->bytes % sizeof(T) == 0);
std::vector<T> storage(range->bytes / sizeof(T));
if (!file_->Read(range->offset, range->bytes, storage.data())) {
@ -131,30 +108,13 @@ class BlobReader {
return true;
}
// The following methods must only be called if `!IsMapped()`.
// Enqueues a BlobIO2 for `ReadAll` to execute.
void Enqueue(const BlobRange& range, void* data);
// Reads in parallel all enqueued requests to the specified destinations.
// Aborts on error.
void ReadAll(hwy::ThreadPool& pool) const;
private:
// Only for use by `Make`.
BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
const BlobStore& bs, Mode mode);
const std::unique_ptr<File> file_;
const uint64_t file_bytes_;
Mode mode_;
std::vector<std::string> keys_;
std::vector<BlobRange> ranges_;
std::unordered_map<std::string, size_t> key_idx_for_key_;
MapPtr mapped_; // only if `kMap`
std::vector<BlobIO2> requests_; // only if `kRead`
};
// Collects references to blobs and writes them all at once with parallel I/O.


@ -19,7 +19,6 @@
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
@ -36,7 +35,7 @@ namespace {
class BlobStoreTest : public testing::Test {};
#endif
void TestWithMapped(Tristate map) {
TEST(BlobStoreTest, TestReadWrite) {
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@ -59,43 +58,30 @@ void TestWithMapped(Tristate map) {
std::fill(buffer.begin(), buffer.end(), 0);
std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
HWY_ASSERT(reader);
const BlobReader reader(path);
HWY_ASSERT_EQ(reader->Keys().size(), 2);
HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());
HWY_ASSERT_EQ(reader.Keys().size(), 2);
HWY_ASSERT_STRING_EQ(reader.Keys()[0].c_str(), keyA.c_str());
HWY_ASSERT_STRING_EQ(reader.Keys()[1].c_str(), keyB.c_str());
const BlobRange* range = reader->Find(keyA);
const BlobRange* range = reader.Find(keyA);
HWY_ASSERT(range);
const uint64_t offsetA = range->offset;
HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign
HWY_ASSERT_EQ(range->bytes, 5);
range = reader->Find(keyB);
range = reader.Find(keyB);
HWY_ASSERT(range);
const uint64_t offsetB = range->offset;
HWY_ASSERT_EQ(offsetB, 2 * 256);
HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
if (!reader->IsMapped()) {
char str[5];
reader->Enqueue(
BlobRange{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
reader->Enqueue(
BlobRange{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
buffer.data());
reader->ReadAll(pool);
HWY_ASSERT_STRING_EQ("DATA", str);
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
}
HWY_ASSERT(
reader->CallWithSpan<char>(keyA, [](const hwy::Span<const char> span) {
reader.CallWithSpan<char>(keyA, [](const hwy::Span<const char> span) {
HWY_ASSERT_EQ(span.size(), 5);
HWY_ASSERT_STRING_EQ("DATA", span.data());
}));
HWY_ASSERT(
reader->CallWithSpan<float>(keyB, [](const hwy::Span<const float> span) {
reader.CallWithSpan<float>(keyB, [](const hwy::Span<const float> span) {
HWY_ASSERT_EQ(span.size(), 4);
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), span.data(), span.size());
}));
@ -104,11 +90,6 @@ void TestWithMapped(Tristate map) {
unlink(path_str);
}
TEST(BlobStoreTest, TestReadWrite) {
TestWithMapped(Tristate::kFalse);
TestWithMapped(Tristate::kTrue);
}
// Ensures padding works for any number of random-sized blobs.
TEST(BlobStoreTest, TestNumBlobs) {
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
@ -143,17 +124,14 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(blobs.size() == num_blobs);
writer.WriteAll(pool, path);
const Tristate map = Tristate::kFalse;
std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
HWY_ASSERT(reader);
HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
std::to_string(i).c_str());
const BlobRange* range = reader->Find(keys[i]);
HWY_ASSERT_STRING_EQ(reader.Keys()[i].c_str(), std::to_string(i).c_str());
const BlobRange* range = reader.Find(keys[i]);
HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
HWY_ASSERT(reader->CallWithSpan<uint8_t>(
HWY_ASSERT(reader.CallWithSpan<uint8_t>(
keys[i], [path_str, num_blobs, i, range,
&blobs](const hwy::Span<const uint8_t> span) {
HWY_ASSERT_EQ(blobs[i].size(), span.size());

io/io.cc

@ -19,20 +19,42 @@
// check this in source code because we support multiple build systems.
#if !HWY_OS_WIN
// Request POSIX 2008, including `pread()` and `posix_fadvise()`.
// Request POSIX 2008, including `pread()` and `posix_fadvise()`. This also
// implies `_POSIX_C_SOURCE`.
#if !defined(_XOPEN_SOURCE) || _XOPEN_SOURCE < 700
#undef _XOPEN_SOURCE
#define _XOPEN_SOURCE 700
#endif
#if !defined(_POSIX_C_SOURCE) || _POSIX_C_SOURCE < 200809
#define _POSIX_C_SOURCE 200809
#define _XOPEN_SOURCE 700 // SUSv4
#endif
// Make `off_t` 64-bit even on 32-bit systems. Works for Android >= r15c.
#undef _FILE_OFFSET_BITS
#define _FILE_OFFSET_BITS 64
#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
(!defined(__ANDROID_API__) || __ANDROID_API__ >= 24)
#define GEMMA_IO_PREADV 1
#else
#define GEMMA_IO_PREADV 0
#endif
#if (HWY_OS_LINUX || HWY_OS_FREEBSD) && \
(!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
#define GEMMA_IO_FADVISE 1
#else
#define GEMMA_IO_FADVISE 0
#endif
#if GEMMA_IO_PREADV
// _DEFAULT_SOURCE is the modern replacement for the _BSD_SOURCE that the
// preadv documentation specifies.
#ifndef _DEFAULT_SOURCE
#define _DEFAULT_SOURCE
#endif
#include <errno.h>
#include <sys/uio.h> // preadv
#endif // GEMMA_IO_PREADV
#include <fcntl.h> // open
#include <limits.h> // IOV_MAX
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // SEEK_END - unistd isn't enough for IDE.
@ -43,6 +65,7 @@
#include <unistd.h> // read, write, close
#include <memory>
#include <string>
#include "io/io.h"
#include "util/allocator.h"
@ -119,6 +142,8 @@ class FilePosix : public File {
HWY_ASSERT(munmap(ptr, mapping_size) == 0);
}));
}
int Handle() const override { return fd_; }
}; // FilePosix
HWY_MAYBE_UNUSED extern std::unique_ptr<File> OpenFileGoogle(
@ -133,15 +158,83 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
const int fd = open(filename.path.c_str(), flags, 0644);
if (fd < 0) return file;
#if HWY_OS_LINUX && (!defined(__ANDROID_API__) || __ANDROID_API__ >= 21)
#if GEMMA_IO_FADVISE
if (is_read) {
// Doubles the readahead window, which seems slightly faster when cached.
(void)posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL);
}
#endif
#endif // GEMMA_IO_FADVISE
return std::make_unique<FilePosix>(fd);
}
std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
std::unique_ptr<File> file = OpenFileOrNull(filename, mode);
if (!file) {
HWY_ABORT("Failed to open %s", filename.path.c_str());
}
return file;
}
std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrAbort(path, "r");
const size_t size = file->FileSize();
if (size == 0) {
HWY_ABORT("Empty file %s", path.path.c_str());
}
std::string content(size, ' ');
if (!file->Read(0, size, content.data())) {
HWY_ABORT("Failed to read %s", path.path.c_str());
}
return content;
}
#ifdef IOV_MAX
constexpr size_t kMaxSpans = IOV_MAX;
#else
constexpr size_t kMaxSpans = 1024; // typical IOV_MAX value on Linux
#endif
IOBatch::IOBatch(uint64_t offset, size_t key_idx)
: offset_(offset), key_idx_(key_idx) {
spans_.reserve(kMaxSpans);
}
// Returns false if the batch is full; if so, create a new batch and call Add again.
bool IOBatch::Add(void* mem, size_t bytes) {
if (spans_.size() >= kMaxSpans) return false;
if (total_bytes_ + bytes > 0x7FFFF000) return false; // Linux limit
spans_.push_back({.mem = mem, .bytes = bytes});
total_bytes_ += bytes;
return true;
}
uint64_t IOBatch::Read(const File& file) const {
#if GEMMA_IO_PREADV
HWY_ASSERT(!spans_.empty());
ssize_t bytes_read;
for (;;) {
bytes_read =
preadv(file.Handle(), reinterpret_cast<const iovec*>(spans_.data()),
static_cast<int>(spans_.size()), offset_);
if (bytes_read >= 0) break;
if (errno == EINTR) continue; // signal: retry
HWY_WARN("preadv failed, errno %d.", errno);
return 0;
}
return static_cast<uint64_t>(bytes_read);
#else
uint64_t total = 0;
uint64_t offset = offset_;
for (const IOSpan& span : spans_) {
if (!file.Read(offset, span.bytes, span.mem)) return 0;
total += span.bytes;
offset += span.bytes;
}
return total;
#endif
}
} // namespace gcpp
#endif // !HWY_OS_WIN

io/io.h

@ -22,6 +22,7 @@
#include <memory>
#include <string>
#include <utility> // std::move
#include <vector>
#include "util/allocator.h"
#include "hwy/base.h"
@ -49,22 +50,66 @@ class File {
virtual uint64_t FileSize() const = 0;
// Returns true if all the requested bytes were read.
// Thread-compatible.
virtual bool Read(uint64_t offset, uint64_t size, void* to) const = 0;
// Returns true if all the requested bytes were written.
// Thread-compatible.
virtual bool Write(const void* from, uint64_t size, uint64_t offset) = 0;
// Maps the entire file into read-only memory or returns nullptr on failure.
// We do not support offsets because Windows requires them to be a multiple of
// the allocation granularity, which is 64 KiB. Some implementations may fail
// if the file is zero-sized and return a nullptr.
// if the file is zero-sized and return a nullptr. Non-const because it may
// modify internal state. This is only expected to be called once per file.
virtual MapPtr Map() = 0;
// For use by `IOBatch::Read`.
virtual int Handle() const { return -1; }
};
// Returns nullptr on failure. `mode` is either "r" or "w+". This is not just
// named 'OpenFile' to avoid a conflict with Windows.h #define.
std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode);
// As above, but aborts instead of returning nullptr.
std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode);
// Compatible with Linux iovec.
struct IOSpan {
void* mem;
size_t bytes;
};
// Wrapper for Linux/BSD `preadv`, calling `File::Read` on other systems. To
// insert row padding, we previously issued one IO per tensor row, which is
// expensive. `preadv` reduces up to 1024 syscalls to 1.
// The file data must be contiguous starting from `IOBatch::offset_`, because
// `preadv` does not support per-`IOSpan` offsets.
class IOBatch {
public:
// Reserves memory in `spans_`. `key_idx` identifies the blob/tensor.
explicit IOBatch(uint64_t offset, size_t key_idx);
// The next `bytes` will be read from the file into `mem`.
// Returns false if the batch is full; if so, create a new batch and call Add again.
bool Add(void* mem, size_t bytes);
uint64_t Offset() const { return offset_; }
uint64_t TotalBytes() const { return total_bytes_; }
size_t KeyIdx() const { return key_idx_; }
// Returns the total number of bytes read, or 0 if any I/O failed.
// Thread-compatible.
uint64_t Read(const File& file) const;
private:
uint64_t offset_;
uint64_t total_bytes_ = 0;
size_t key_idx_;
std::vector<IOSpan> spans_; // contiguous in the file.
};
// Wrapper for strings representing a path name. Differentiates vs. arbitrary
// strings and supports shortening for display purposes.
struct Path {
@ -97,21 +142,7 @@ struct Path {
};
// Aborts on error.
static inline HWY_MAYBE_UNUSED std::string ReadFileToString(const Path& path) {
std::unique_ptr<File> file = OpenFileOrNull(path, "r");
if (!file) {
HWY_ABORT("Failed to open %s", path.path.c_str());
}
const size_t size = file->FileSize();
if (size == 0) {
HWY_ABORT("Empty file %s", path.path.c_str());
}
std::string content(size, ' ');
if (!file->Read(0, size, content.data())) {
HWY_ABORT("Failed to read %s", path.path.c_str());
}
return content;
}
std::string ReadFileToString(const Path& path);
} // namespace gcpp
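
For reference, here is a possible usage of the IOBatch API declared above. This is a hypothetical sketch, not part of the commit: the function ReadBlobRows and its parameters are invented, but it mirrors how MakeBatches and ReadBatches in weights.cc use the API (one batch per blob, a new batch whenever Add() reports the current one is full, then one preadv-backed Read() per batch).

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "io/io.h"     // File, IOBatch, OpenFileOrAbort, Path
#include "hwy/base.h"  // HWY_ASSERT, HWY_ABORT

namespace gcpp {

// Reads `num_rows` rows of `row_bytes` each, contiguous in the file starting
// at `blob_offset`, into `dest`, whose rows are `stride_bytes` apart.
void ReadBlobRows(const Path& path, uint64_t blob_offset, size_t key_idx,
                  size_t num_rows, size_t row_bytes, size_t stride_bytes,
                  uint8_t* dest) {
  const std::unique_ptr<File> file = OpenFileOrAbort(path, "r");
  std::vector<IOBatch> batches;
  batches.emplace_back(blob_offset, key_idx);
  uint64_t offset = blob_offset;
  for (size_t r = 0; r < num_rows; ++r) {
    uint8_t* row_mem = dest + r * stride_bytes;
    if (!batches.back().Add(row_mem, row_bytes)) {  // current batch is full
      batches.emplace_back(offset, key_idx);
      HWY_ASSERT(batches.back().Add(row_mem, row_bytes));
    }
    offset += row_bytes;
  }
  // The commit's ReadBatches runs this loop in parallel via a thread pool.
  for (const IOBatch& batch : batches) {
    if (batch.Read(*file) != batch.TotalBytes()) HWY_ABORT("Read failed.");
  }
}

}  // namespace gcpp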