Write SBS progressively.

(1) Directly write to file in BlobWriter::Add and destruct the MatOwner to release the rams.

(2) Write a fake header to indicate this is V2, and write correct header and directory at the end of the file.

(3) Tested on loading sbs written the old way, and new way, both worked.

PiperOrigin-RevId: 789306837
This commit is contained in:
Charles Zhao 2025-07-31 06:05:02 -07:00 committed by Copybara-Service
parent 8715eda512
commit 50ee1a3e92
10 changed files with 255 additions and 169 deletions

View File

@ -75,8 +75,8 @@ class SbsWriterImpl : public ISbsWriter {
} }
mat.AppendTo(serialized_mat_ptrs_); mat.AppendTo(serialized_mat_ptrs_);
mat_owners_.push_back(MatOwner()); MatOwner mat_owner;
mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked); mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
// Handle gemma_export_test's MockArray. Write blobs so that the test // Handle gemma_export_test's MockArray. Write blobs so that the test
// succeeds, but we only have 10 floats, not the full tensor. // succeeds, but we only have 10 floats, not the full tensor.
@ -97,7 +97,9 @@ class SbsWriterImpl : public ISbsWriter {
} }
public: public:
SbsWriterImpl() : ctx_(ThreadingArgs()) {} SbsWriterImpl(const std::string& sbs_path)
: ctx_(ThreadingArgs()),
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
void Insert(const char* name, F32Span weights, Type type, void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override { const TensorInfo& tensor_info) override {
@ -120,23 +122,23 @@ class SbsWriterImpl : public ISbsWriter {
} }
} }
void Write(const ModelConfig& config, const std::string& tokenizer_path, void Write(const ModelConfig& config,
const std::string& path) override { const std::string& tokenizer_path) override {
const GemmaTokenizer tokenizer( const GemmaTokenizer tokenizer(
tokenizer_path.empty() ? kMockTokenizer tokenizer_path.empty() ? kMockTokenizer
: ReadFileToString(Path(tokenizer_path))); : ReadFileToString(Path(tokenizer_path)));
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_, WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_);
ctx_.pools.Pool(), gcpp::Path(path));
} }
ThreadingContext ctx_; ThreadingContext ctx_;
std::vector<MatOwner> mat_owners_;
CompressWorkingSet working_set_; CompressWorkingSet working_set_;
BlobWriter writer_; BlobWriter writer_;
std::vector<uint32_t> serialized_mat_ptrs_; std::vector<uint32_t> serialized_mat_ptrs_;
}; };
ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); } ISbsWriter* NewSbsWriter(const std::string& sbs_path) {
return new SbsWriterImpl(sbs_path);
}
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp
@ -147,7 +149,8 @@ namespace gcpp {
HWY_EXPORT(NewSbsWriter); HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {} SbsWriter::SbsWriter(const std::string& path)
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}
SbsReader::SbsReader(const std::string& path) SbsReader::SbsReader(const std::string& path)
: reader_(Path(path)), model_(reader_) {} : reader_(Path(path)), model_(reader_) {}

View File

@ -44,24 +44,22 @@ class ISbsWriter {
const TensorInfo& tensor_info) = 0; const TensorInfo& tensor_info) = 0;
virtual void Write(const ModelConfig& config, virtual void Write(const ModelConfig& config,
const std::string& tokenizer_path, const std::string& tokenizer_path) = 0;
const std::string& path) = 0;
}; };
// Non-virtual class used by pybind that calls the interface's virtual methods. // Non-virtual class used by pybind that calls the interface's virtual methods.
// This avoids having to register the derived types with pybind. // This avoids having to register the derived types with pybind.
class SbsWriter { class SbsWriter {
public: public:
SbsWriter(); explicit SbsWriter(const std::string& sbs_path);
void Insert(const char* name, F32Span weights, Type type, void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) { const TensorInfo& tensor_info) {
impl_->Insert(name, weights, type, tensor_info); impl_->Insert(name, weights, type, tensor_info);
} }
void Write(const ModelConfig& config, const std::string& tokenizer_path, void Write(const ModelConfig& config, const std::string& tokenizer_path) {
const std::string& path) { impl_->Write(config, tokenizer_path);
impl_->Write(config, tokenizer_path, path);
} }
private: private:

View File

@ -44,10 +44,9 @@ static void CallWithF32Span(SbsWriter& writer, const char* name,
PYBIND11_MODULE(compression, m) { PYBIND11_MODULE(compression, m) {
class_<SbsWriter>(m, "SbsWriter") class_<SbsWriter>(m, "SbsWriter")
.def(init<>()) .def(init<std::string>())
.def("insert", CallWithF32Span<&SbsWriter::Insert>) .def("insert", CallWithF32Span<&SbsWriter::Insert>)
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"), .def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"));
arg("path"));
class_<MatPtr>(m, "MatPtr") class_<MatPtr>(m, "MatPtr")
// No init, only created within C++. // No init, only created within C++.

View File

@ -30,7 +30,8 @@ class CompressionTest(absltest.TestCase):
info_192.axes = [0] info_192.axes = [0]
info_192.shape = [192] info_192.shape = [192]
writer = compression.SbsWriter() temp_file = self.create_tempfile("test.sbs")
writer = compression.SbsWriter(temp_file.full_path)
writer.insert( writer.insert(
"tensor0", "tensor0",
# Large enough to require scaling. # Large enough to require scaling.
@ -95,8 +96,7 @@ class CompressionTest(absltest.TestCase):
configs.PromptWrapping.GEMMA_IT, configs.PromptWrapping.GEMMA_IT,
) )
tokenizer_path = "" # no tokenizer required for testing tokenizer_path = "" # no tokenizer required for testing
temp_file = self.create_tempfile("test.sbs") writer.write(config, tokenizer_path)
writer.write(config, tokenizer_path, temp_file.full_path)
print("Ignore next two warnings; test does not enable model deduction.") print("Ignore next two warnings; test does not enable model deduction.")
reader = compression.SbsReader(temp_file.full_path) reader = compression.SbsReader(temp_file.full_path)

View File

@ -618,11 +618,11 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
Gemma::~Gemma() = default; Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, NestedPools& pools) const { void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer; BlobWriter writer(weights_path, pools.Pool());
const std::vector<uint32_t> serialized_mat_ptrs = const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer); weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs, WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
writer, pools.Pool(), weights_path); writer);
} }
void Gemma::Generate(const RuntimeConfig& runtime_config, void Gemma::Generate(const RuntimeConfig& runtime_config,

View File

@ -444,8 +444,7 @@ static void AddBlob(const char* name, const std::vector<uint32_t>& data,
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs, const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer, hwy::ThreadPool& pool, BlobWriter& writer) {
const Path& path) {
HWY_ASSERT(config.model != Model::UNKNOWN); HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.weight != Type::kUnknown); HWY_ASSERT(config.weight != Type::kUnknown);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
@ -459,7 +458,7 @@ void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
AddBlob(kMatPtrsName, serialized_mat_ptrs, writer); AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
writer.WriteAll(pool, path); writer.WriteAll();
} }
} // namespace gcpp } // namespace gcpp

View File

@ -105,8 +105,7 @@ class ModelStore {
// produces a single BlobStore file holding everything required for inference. // produces a single BlobStore file holding everything required for inference.
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs, const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer, hwy::ThreadPool& pool, BlobWriter& writer);
const Path& path);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_

View File

@ -86,16 +86,26 @@ static_assert(sizeof(Header) == 16);
// A write I/O request, each serviced by one thread in a pool. // A write I/O request, each serviced by one thread in a pool.
struct BlobIO { struct BlobIO {
BlobIO(BlobRange range, void* data) : range(range), data(data) {} BlobIO(BlobRange range, const void* data) : range(range), data(data) {}
BlobRange range; BlobRange range;
void* data; // Read-only for writes. const void* data; // Read-only for writes.
}; };
// Little-endian on-disk representation: a fixed-size `Header`, then a padded // Little-endian on-disk representation:
// variable-length 'directory' of blob keys and their offset/sizes, then the // For V1: the file is represented as
// 'payload' of each blob's data with padding in between, followed by padding to // Header + Directory + PadToBlobAlign + Payload + PayToEndAlign.
// `kEndAlign`. Keys are unique, opaque 128-bit keys. // For V2: the file is represented as
// Header + PadToBlobAlign + Payload + PadToEndAlign + Directory + Header
// The Header at the beginning has num_blobs == 0; and the Header at the end has
// the correct num_blobs.
//
// Actual payload is indexed by the directory with keys, offset and bytes; keys
// are unique, opaque 128-bit keys.
//
// The file format deliberately omits a version number because it is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob
// contents or type should be handled by renaming keys.
// //
// The file format deliberately omits a version number because it is unchanging. // The file format deliberately omits a version number because it is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob // Additional data may be added only inside new blobs. Changes to the blob
@ -106,23 +116,22 @@ struct BlobIO {
class BlobStore { class BlobStore {
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
// Arbitrary upper limit to avoid allocating a huge vector. // Upper limit to avoid allocating a huge vector.
static constexpr size_t kMaxBlobs = 64 * 1024; static constexpr size_t kMaxBlobs = 16 * 1024;
// Returns the end of the directory, including padding, which is also the // Returns the size of padded header and directory, which is also the start of
// start of the first payload. `num_blobs` is `NumBlobs()` if the header is // the first payload for V1. `num_blobs` is `NumBlobs()` if the header is
// already available, otherwise the number of blobs to be written. // already available, otherwise the number of blobs to be written.
static size_t PaddedDirEnd(size_t num_blobs) { static size_t PaddedHeaderAndDirBytes(size_t num_blobs) {
HWY_ASSERT(num_blobs < kMaxBlobs); HWY_ASSERT(num_blobs < kMaxBlobs);
// Per blob, a key and offset/size. // Per blob, a key and offset/size.
return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs); return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs);
} }
static uint64_t PaddedPayloadBytes(size_t num_blobs, static uint64_t PaddedPayloadBytes(const std::vector<size_t>& blob_sizes) {
const hwy::Span<const uint8_t> blobs[]) {
uint64_t total_payload_bytes = 0; uint64_t total_payload_bytes = 0;
for (size_t i = 0; i < num_blobs; ++i) { for (size_t blob_size : blob_sizes) {
total_payload_bytes += RoundUpToAlign(blobs[i].size()); total_payload_bytes += RoundUpToAlign(blob_size);
} }
// Do not round up to `kEndAlign` because the padding also depends on the // Do not round up to `kEndAlign` because the padding also depends on the
// directory size. Here we only count the payload. // directory size. Here we only count the payload.
@ -136,6 +145,72 @@ class BlobStore {
} }
} }
bool ParseHeaderAndDirectoryV1(const File& file) {
is_file_v2_ = false;
// Read header from the beginning of the file.
if (!file.Read(0, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
return false;
}
if (header_.magic != kMagic) {
HWY_WARN("BlobStore header magic %08x does not match %08x.",
header_.magic, kMagic);
return false;
}
if (header_.num_blobs == 0) {
// Should parse as V2.
return false;
}
if (header_.num_blobs > kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return false;
}
directory_.resize(header_.num_blobs * 2);
const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
// Read directory after the header.
if (!file.Read(sizeof(header_), directory_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return false;
}
HWY_ASSERT(IsValid(file.FileSize()));
return true;
}
bool ParseHeaderAndDirectoryV2(const File& file) {
is_file_v2_ = true;
// Read header from the end of the file.
size_t offset = file.FileSize() - sizeof(header_);
if (!file.Read(offset, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
return false;
}
if (header_.magic != kMagic) {
HWY_WARN("BlobStore header magic %08x does not match %08x.",
header_.magic, kMagic);
return false;
}
if (header_.num_blobs > kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return false;
}
directory_.resize(header_.num_blobs * 2);
const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
offset -= directory_bytes;
// Read directory immediately before the header.
if (!file.Read(offset, directory_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return false;
}
HWY_ASSERT(IsValid(file.FileSize()));
return true;
}
public: public:
template <typename T> template <typename T>
static T RoundUpToAlign(T size_or_offset) { static T RoundUpToAlign(T size_or_offset) {
@ -144,60 +219,46 @@ class BlobStore {
// Reads header/directory from file. // Reads header/directory from file.
explicit BlobStore(const File& file) { explicit BlobStore(const File& file) {
if (!file.Read(0, sizeof(header_), &header_)) { if (ParseHeaderAndDirectoryV1(file)) {
HWY_WARN("Failed to read BlobStore header.");
return; return;
} }
// Avoid allocating a huge vector. if (ParseHeaderAndDirectoryV2(file)) {
if (header_.num_blobs >= kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return;
}
const size_t padded_dir_end = PaddedDirEnd(NumBlobs());
const size_t padded_dir_bytes = padded_dir_end - sizeof(header_);
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
directory_.resize(padded_dir_bytes / kU128Bytes);
if (!file.Read(sizeof(header_), padded_dir_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return; return;
} }
HWY_ABORT("Failed to read BlobStore header or directory.");
} }
// Initializes header/directory for writing to disk. // Initializes header/directory for writing to disk.
BlobStore(size_t num_blobs, const hwy::uint128_t keys[], BlobStore(const std::vector<hwy::uint128_t>& keys,
const hwy::Span<const uint8_t> blobs[]) { const std::vector<size_t>& blob_sizes) {
const size_t num_blobs = keys.size();
HWY_ASSERT(num_blobs < kMaxBlobs); // Ensures safe to cast to u32. HWY_ASSERT(num_blobs < kMaxBlobs); // Ensures safe to cast to u32.
HWY_ASSERT(keys && blobs); HWY_ASSERT(keys.size() == blob_sizes.size());
EnsureUnique(hwy::Span<const hwy::uint128_t>(keys, num_blobs)); EnsureUnique(hwy::Span<const hwy::uint128_t>(keys.data(), num_blobs));
uint64_t offset = PaddedDirEnd(num_blobs);
const size_t padded_dir_bytes =
static_cast<size_t>(offset) - sizeof(header_);
// Set header_.
header_.magic = kMagic; header_.magic = kMagic;
header_.num_blobs = static_cast<uint32_t>(num_blobs); header_.num_blobs = static_cast<uint32_t>(num_blobs);
header_.file_bytes = hwy::RoundUpTo(
offset + PaddedPayloadBytes(num_blobs, blobs), kEndAlign);
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0); const size_t size_before_blobs = BytesBeforeBlobsV2().size();
directory_.resize(padded_dir_bytes / kU128Bytes); header_.file_bytes =
hwy::CopyBytes(keys, directory_.data(), num_blobs * kU128Bytes); hwy::RoundUpTo(size_before_blobs + PaddedPayloadBytes(blob_sizes) +
PaddedHeaderAndDirBytes(num_blobs),
kEndAlign);
// Set first num_blobs elements of directory_ which are the keys.
directory_.resize(num_blobs * 2);
hwy::CopyBytes(keys.data(), directory_.data(), num_blobs * kU128Bytes);
EnsureUnique(Keys()); EnsureUnique(Keys());
// `SetRange` below will fill `directory_[num_blobs, 2 * num_blobs)`.
hwy::ZeroBytes(directory_.data() + 2 * num_blobs,
padded_dir_bytes - 2 * num_blobs * kU128Bytes);
// We already zero-initialized the directory padding; // Set the second half of directory_ which is the offsets and sizes.
// `BlobWriter::WriteAll` takes care of padding after each blob via an uint64_t offset = size_before_blobs;
// additional I/O.
for (size_t i = 0; i < num_blobs; ++i) { for (size_t i = 0; i < num_blobs; ++i) {
HWY_ASSERT(blobs[i].data() != nullptr); SetRange(i, offset, blob_sizes[i]);
SetRange(i, offset, blobs[i].size()); offset = RoundUpToAlign(offset + blob_sizes[i]);
offset = RoundUpToAlign(offset + blobs[i].size());
} }
// When writing new files, we always pad to `kEndAlign`.
HWY_ASSERT(hwy::RoundUpTo(offset, kEndAlign) == header_.file_bytes); HWY_ASSERT(IsValid(FileSize()));
} }
// Must be checked by readers before other methods. // Must be checked by readers before other methods.
@ -221,7 +282,10 @@ class BlobStore {
} }
// Ensure blobs are back to back. // Ensure blobs are back to back.
uint64_t expected_offset = PaddedDirEnd(NumBlobs()); const size_t size_before_blobs = BytesBeforeBlobs().size();
const size_t size_after_blobs = BytesAfterBlobs().size();
uint64_t expected_offset = size_before_blobs;
for (size_t key_idx = 0; key_idx < NumBlobs(); ++key_idx) { for (size_t key_idx = 0; key_idx < NumBlobs(); ++key_idx) {
uint64_t actual_offset; uint64_t actual_offset;
size_t bytes; size_t bytes;
@ -236,7 +300,7 @@ class BlobStore {
} }
// Previously files were not padded to `kEndAlign`, so also allow that. // Previously files were not padded to `kEndAlign`, so also allow that.
if (expected_offset != header_.file_bytes && if (expected_offset != header_.file_bytes &&
hwy::RoundUpTo(expected_offset, kEndAlign) != header_.file_bytes) { expected_offset + size_after_blobs != header_.file_bytes) {
HWY_WARN("Invalid BlobStore: end of blobs %zu but file size %zu.", HWY_WARN("Invalid BlobStore: end of blobs %zu but file size %zu.",
static_cast<size_t>(expected_offset), static_cast<size_t>(expected_offset),
static_cast<size_t>(header_.file_bytes)); static_cast<size_t>(header_.file_bytes));
@ -246,20 +310,71 @@ class BlobStore {
return true; // all OK return true; // all OK
} }
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const { static std::vector<uint8_t> BytesBeforeBlobsV2() {
const size_t key_idx = 0; // not actually associated with a key/blob const Header kFakeHeaderV2 = {
writes.emplace_back( .magic = kMagic,
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx}, .num_blobs = 0,
// members are const and BlobIO requires non-const pointers, and they .file_bytes = kEndAlign,
// are not modified by file writes. };
const_cast<Header*>(&header_)); std::vector<uint8_t> header(PaddedHeaderAndDirBytes(0));
writes.emplace_back( hwy::CopyBytes(&kFakeHeaderV2, header.data(), sizeof(Header));
BlobRange{.offset = sizeof(header_), return header;
.bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
.key_idx = key_idx},
const_cast<hwy::uint128_t*>(directory_.data()));
} }
std::vector<uint8_t> BytesBeforeBlobs() const {
if (is_file_v2_) {
return BytesBeforeBlobsV2();
} else {
const size_t padded_header_and_directory_size =
PaddedHeaderAndDirBytes(NumBlobs());
std::vector<uint8_t> header_and_directory(
padded_header_and_directory_size);
// Copy header_ at the beginning (offset 0)
hwy::CopyBytes(&header_, header_and_directory.data(), sizeof(header_));
// Copy directory_ immediately after the header_
hwy::CopyBytes(directory_.data(),
header_and_directory.data() + sizeof(header_),
2 * kU128Bytes * NumBlobs());
return header_and_directory;
}
}
std::vector<uint8_t> BytesAfterBlobs() const {
// Gets blob end.
uint64_t offset = 0;
size_t bytes = 0;
GetRange(NumBlobs() - 1, offset, bytes);
const uint64_t blob_end = RoundUpToAlign(offset + bytes);
// For V1, just return the file paddings.
if (!is_file_v2_) {
return std::vector<uint8_t>(FileSize() - blob_end);
}
const size_t header_and_directory_with_file_padding_size =
FileSize() - blob_end;
std::vector<uint8_t> header_and_directory(
header_and_directory_with_file_padding_size);
const size_t header_size = sizeof(Header);
const size_t directory_size = 2 * kU128Bytes * NumBlobs();
// Copy header_ at the end.
offset = header_and_directory_with_file_padding_size - header_size;
hwy::CopyBytes(&header_, header_and_directory.data() + offset, header_size);
// Copy directory_ immediately before the header_.
offset -= directory_size;
hwy::CopyBytes(directory_.data(), header_and_directory.data() + offset,
directory_size);
return header_and_directory;
}
size_t FileSize() const { return header_.file_bytes; }
size_t NumBlobs() const { return static_cast<size_t>(header_.num_blobs); } size_t NumBlobs() const { return static_cast<size_t>(header_.num_blobs); }
// Not the entirety of `directory_`! The second half is offset/size. // Not the entirety of `directory_`! The second half is offset/size.
@ -291,6 +406,8 @@ class BlobStore {
val.hi = bytes; val.hi = bytes;
} }
bool is_file_v2_ = true;
Header header_; Header header_;
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`. std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
@ -303,7 +420,6 @@ BlobReader::BlobReader(const Path& blob_path)
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str()); if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
BlobStore bs(*file_); BlobStore bs(*file_);
HWY_ASSERT(bs.IsValid(file_bytes_)); // IsValid already printed a warning
keys_.reserve(bs.NumBlobs()); keys_.reserve(bs.NumBlobs());
for (const hwy::uint128_t key : bs.Keys()) { for (const hwy::uint128_t key : bs.Keys()) {
@ -324,8 +440,8 @@ BlobReader::BlobReader(const Path& blob_path)
// Split into chunks for load-balancing even if blob sizes vary. // Split into chunks for load-balancing even if blob sizes vary.
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes, static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
uint8_t* data, std::vector<BlobIO>& writes) { const uint8_t* data, std::vector<BlobIO>& writes) {
constexpr size_t kChunkBytes = 4 * 1024 * 1024; constexpr size_t kChunkBytes = 10 * 1024 * 1024;
const uint64_t end = offset + bytes; const uint64_t end = offset + bytes;
// Split into whole chunks and possibly one remainder. // Split into whole chunks and possibly one remainder.
if (end >= kChunkBytes) { if (end >= kChunkBytes) {
@ -341,83 +457,42 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx}, BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
data); data);
} }
// Write a padding if necessary.
static constexpr uint8_t kZeros[kBlobAlign] = {0};
const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
if (padding > 0) {
writes.emplace_back(
BlobRange{.offset = end, .bytes = padding, .key_idx = key_idx},
static_cast<const uint8_t*>(kZeros));
}
} }
static void EnqueueWritesForBlobs(const BlobStore& bs, BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool)
const hwy::Span<const uint8_t> blobs[], : file_(OpenFileOrNull(filename, "w+")), pool_(pool) {
std::vector<uint8_t>& zeros, if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
std::vector<BlobIO>& writes) { // Write a fake header to the beginning of the file.
// All-zero buffer used to write padding to the file without copying the std::vector<uint8_t> bytes_before_blobs = BlobStore::BytesBeforeBlobsV2();
// input blobs. file_->Write(bytes_before_blobs.data(), bytes_before_blobs.size(), 0);
static constexpr uint8_t kZeros[kBlobAlign] = {0};
uint64_t file_end = 0; // for padding
for (size_t key_idx = 0; key_idx < bs.NumBlobs(); ++key_idx) {
// We know the size, but `BlobStore` tells us the offset to write each blob.
uint64_t offset;
size_t bytes;
bs.GetRange(key_idx, offset, bytes);
HWY_ASSERT(offset != 0);
HWY_ASSERT(bytes == blobs[key_idx].size());
const uint64_t new_file_end = offset + bytes;
HWY_ASSERT(new_file_end >= file_end); // blobs are ordered by offset
file_end = new_file_end;
EnqueueChunks(key_idx, offset, bytes,
const_cast<uint8_t*>(blobs[key_idx].data()), writes);
const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
if (padding != 0) {
HWY_ASSERT(padding <= kBlobAlign);
writes.emplace_back(
BlobRange{
.offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
const_cast<uint8_t*>(kZeros));
}
}
const size_t padding = hwy::RoundUpTo(file_end, kEndAlign) - file_end;
if (padding != 0) {
// Bigger than `kZeros`, better to allocate than issue multiple I/Os. Must
// remain alive until the last I/O is done.
zeros.resize(padding);
writes.emplace_back(
BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
zeros.data());
}
} }
void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) { void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
HWY_ASSERT(data != nullptr); HWY_ASSERT(data != nullptr);
HWY_ASSERT(bytes != 0); HWY_ASSERT(bytes != 0);
keys_.push_back(KeyFromString(key.c_str())); keys_.push_back(KeyFromString(key.c_str()));
blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes); blob_sizes_.push_back(bytes);
}
void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
const size_t num_blobs = keys_.size();
HWY_ASSERT(num_blobs != 0);
HWY_ASSERT(num_blobs == blobs_.size());
std::vector<BlobIO> writes; std::vector<BlobIO> writes;
writes.reserve(16384); EnqueueChunks(keys_.size() - 1, file_->FileSize(), bytes,
static_cast<const uint8_t*>(data), writes);
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
bs.EnqueueWriteForHeaderAndDirectory(writes);
std::vector<uint8_t> zeros;
EnqueueWritesForBlobs(bs, blobs_.data(), zeros, writes);
// Create/replace existing file.
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
hwy::ThreadPool null_pool(0); hwy::ThreadPool null_pool(0);
hwy::ThreadPool& pool_or_serial = file->IsAppendOnly() ? null_pool : pool; hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_;
pool_or_serial.Run( pool_or_serial.Run(
0, writes.size(), [this, &file, &writes](uint64_t i, size_t /*thread*/) { 0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range; const BlobRange& range = writes[i].range;
if (!file->Write(writes[i].data, range.bytes, range.offset)) { if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]); const std::string& key = StringFromKey(keys_[range.key_idx]);
HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.", HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.",
key.c_str(), static_cast<size_t>(range.offset), range.bytes, key.c_str(), static_cast<size_t>(range.offset), range.bytes,
@ -426,4 +501,13 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
}); });
} }
void BlobWriter::WriteAll() {
const BlobStore bs = BlobStore(keys_, blob_sizes_);
// Write the rest of the bytes, which contains: paddings + directory + header.
const auto bytes_after_blobs = bs.BytesAfterBlobs();
file_->Write(bytes_after_blobs.data(), bytes_after_blobs.size(),
file_->FileSize());
}
} // namespace gcpp } // namespace gcpp

View File

@ -116,6 +116,8 @@ class BlobReader {
// does not make sense to call the methods concurrently. // does not make sense to call the methods concurrently.
class BlobWriter { class BlobWriter {
public: public:
explicit BlobWriter(const Path& filename, hwy::ThreadPool& pool);
void Add(const std::string& key, const void* data, size_t bytes); void Add(const std::string& key, const void* data, size_t bytes);
// For `ModelStore`: this is the `key_idx` of the next blob to be added. // For `ModelStore`: this is the `key_idx` of the next blob to be added.
@ -123,11 +125,13 @@ class BlobWriter {
// Stores all blobs to disk in the given order with padding for alignment. // Stores all blobs to disk in the given order with padding for alignment.
// Aborts on error. // Aborts on error.
void WriteAll(hwy::ThreadPool& pool, const Path& filename); void WriteAll();
private: private:
std::unique_ptr<File> file_;
std::vector<hwy::uint128_t> keys_; std::vector<hwy::uint128_t> keys_;
std::vector<hwy::Span<const uint8_t>> blobs_; std::vector<size_t> blob_sizes_;
hwy::ThreadPool& pool_;
}; };
} // namespace gcpp } // namespace gcpp

View File

@ -52,10 +52,10 @@ TEST(BlobStoreTest, TestReadWrite) {
const std::string keyA("0123456789abcdef"); // max 16 characters const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q"); const std::string keyB("q");
BlobWriter writer; BlobWriter writer(path, pool);
writer.Add(keyA, "DATA", 5); writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer)); writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.WriteAll(pool, path); writer.WriteAll();
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size()); HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
std::fill(buffer.begin(), buffer.end(), 0); std::fill(buffer.begin(), buffer.end(), 0);
@ -69,12 +69,12 @@ TEST(BlobStoreTest, TestReadWrite) {
const BlobRange* range = reader.Find(keyA); const BlobRange* range = reader.Find(keyA);
HWY_ASSERT(range); HWY_ASSERT(range);
const uint64_t offsetA = range->offset; const uint64_t offsetA = range->offset;
HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign HWY_ASSERT_EQ(offsetA, 256);
HWY_ASSERT_EQ(range->bytes, 5); HWY_ASSERT_EQ(range->bytes, 5);
range = reader.Find(keyB); range = reader.Find(keyB);
HWY_ASSERT(range); HWY_ASSERT(range);
const uint64_t offsetB = range->offset; const uint64_t offsetB = range->offset;
HWY_ASSERT_EQ(offsetB, 2 * 256); HWY_ASSERT_EQ(offsetB, offsetA + 256);
HWY_ASSERT_EQ(range->bytes, sizeof(buffer)); HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
HWY_ASSERT( HWY_ASSERT(
@ -106,7 +106,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0); HWY_ASSERT(fd > 0);
const Path path(path_str); const Path path(path_str);
BlobWriter writer; BlobWriter writer(path, pool);
std::vector<std::string> keys; std::vector<std::string> keys;
keys.reserve(num_blobs); keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs; std::vector<std::vector<uint8_t>> blobs;
@ -126,7 +126,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
} }
HWY_ASSERT(keys.size() == num_blobs); HWY_ASSERT(keys.size() == num_blobs);
HWY_ASSERT(blobs.size() == num_blobs); HWY_ASSERT(blobs.size() == num_blobs);
writer.WriteAll(pool, path); writer.WriteAll();
BlobReader reader(path); BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs); HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);