diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc
index 7b5bc15..2de1b67 100644
--- a/compression/python/compression_clif_aux.cc
+++ b/compression/python/compression_clif_aux.cc
@@ -75,8 +75,8 @@ class SbsWriterImpl : public ISbsWriter {
     }
 
     mat.AppendTo(serialized_mat_ptrs_);
-    mat_owners_.push_back(MatOwner());
-    mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
+    MatOwner mat_owner;
+    mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
 
     // Handle gemma_export_test's MockArray. Write blobs so that the test
     // succeeds, but we only have 10 floats, not the full tensor.
@@ -97,7 +97,9 @@ class SbsWriterImpl : public ISbsWriter {
   }
 
  public:
-  SbsWriterImpl() : ctx_(ThreadingArgs()) {}
+  SbsWriterImpl(const std::string& sbs_path)
+      : ctx_(ThreadingArgs()),
+        writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
 
   void Insert(const char* name, F32Span weights, Type type,
               const TensorInfo& tensor_info) override {
@@ -120,23 +122,23 @@ class SbsWriterImpl : public ISbsWriter {
     }
   }
 
-  void Write(const ModelConfig& config, const std::string& tokenizer_path,
-             const std::string& path) override {
+  void Write(const ModelConfig& config,
+             const std::string& tokenizer_path) override {
     const GemmaTokenizer tokenizer(
         tokenizer_path.empty() ? kMockTokenizer
                                : ReadFileToString(Path(tokenizer_path)));
-    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_,
-                    ctx_.pools.Pool(), gcpp::Path(path));
+    WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_);
  }
 
   ThreadingContext ctx_;
-  std::vector<MatOwner> mat_owners_;
   CompressWorkingSet working_set_;
   BlobWriter writer_;
   std::vector<uint32_t> serialized_mat_ptrs_;
 };
 
-ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); }
+ISbsWriter* NewSbsWriter(const std::string& sbs_path) {
+  return new SbsWriterImpl(sbs_path);
+}
 
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
@@ -147,7 +149,8 @@ namespace gcpp {
 
 HWY_EXPORT(NewSbsWriter);
 
-SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
+SbsWriter::SbsWriter(const std::string& path)
+    : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}
 
 SbsReader::SbsReader(const std::string& path)
     : reader_(Path(path)), model_(reader_) {}
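Note (not part of the patch): with this change the output path must be known before any tensors are inserted, because blobs now stream to disk as they are added instead of being buffered until Write(). A minimal sketch of the new call order on the C++ wrapper; the path literal is illustrative and the per-tensor Insert arguments are elided because they are unchanged by this patch:

    #include <string>
    #include "compression/python/compression_clif_aux.h"

    void ExportSketch(const gcpp::ModelConfig& config) {
      // Opens the output file and writes a placeholder header immediately.
      gcpp::SbsWriter writer("/tmp/model.sbs");  // illustrative path
      // ... writer.Insert(name, weights, type, tensor_info) per tensor;
      // each inserted tensor is compressed and written straight to the file.
      writer.Write(config, /*tokenizer_path=*/"");  // appends directory + header
    }
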
diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h
index 19df192..6979865 100644
--- a/compression/python/compression_clif_aux.h
+++ b/compression/python/compression_clif_aux.h
@@ -44,24 +44,22 @@ class ISbsWriter {
                       const TensorInfo& tensor_info) = 0;
 
   virtual void Write(const ModelConfig& config,
-                     const std::string& tokenizer_path,
-                     const std::string& path) = 0;
+                     const std::string& tokenizer_path) = 0;
 };
 
 // Non-virtual class used by pybind that calls the interface's virtual methods.
 // This avoids having to register the derived types with pybind.
 class SbsWriter {
  public:
-  SbsWriter();
+  explicit SbsWriter(const std::string& sbs_path);
 
   void Insert(const char* name, F32Span weights, Type type,
               const TensorInfo& tensor_info) {
     impl_->Insert(name, weights, type, tensor_info);
   }
 
-  void Write(const ModelConfig& config, const std::string& tokenizer_path,
-             const std::string& path) {
-    impl_->Write(config, tokenizer_path, path);
+  void Write(const ModelConfig& config, const std::string& tokenizer_path) {
+    impl_->Write(config, tokenizer_path);
   }
 
  private:
diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc
index 8edbfd5..e3d1556 100644
--- a/compression/python/compression_extension.cc
+++ b/compression/python/compression_extension.cc
@@ -44,10 +44,9 @@ static void CallWithF32Span(SbsWriter& writer, const char* name,
 
 PYBIND11_MODULE(compression, m) {
   class_<SbsWriter>(m, "SbsWriter")
-      .def(init<>())
+      .def(init<std::string>())
       .def("insert", CallWithF32Span<&SbsWriter::Insert>)
-      .def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"),
-           arg("path"));
+      .def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"));
 
   class_<MatPtr>(m, "MatPtr")
       // No init, only created within C++.
diff --git a/compression/python/compression_test.py b/compression/python/compression_test.py
index 2ed0916..e8244ed 100644
--- a/compression/python/compression_test.py
+++ b/compression/python/compression_test.py
@@ -30,7 +30,8 @@ class CompressionTest(absltest.TestCase):
     info_192.axes = [0]
     info_192.shape = [192]
 
-    writer = compression.SbsWriter()
+    temp_file = self.create_tempfile("test.sbs")
+    writer = compression.SbsWriter(temp_file.full_path)
     writer.insert(
         "tensor0",
         # Large enough to require scaling.
@@ -95,8 +96,7 @@ class CompressionTest(absltest.TestCase):
         configs.PromptWrapping.GEMMA_IT,
     )
     tokenizer_path = ""  # no tokenizer required for testing
-    temp_file = self.create_tempfile("test.sbs")
-    writer.write(config, tokenizer_path, temp_file.full_path)
+    writer.write(config, tokenizer_path)
 
     print("Ignore next two warnings; test does not enable model deduction.")
     reader = compression.SbsReader(temp_file.full_path)
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 19d9926..496c21d 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -618,11 +618,11 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
 Gemma::~Gemma() = default;
 
 void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
-  BlobWriter writer;
+  BlobWriter writer(weights_path, pools.Pool());
   const std::vector<uint32_t> serialized_mat_ptrs =
       weights_.AddTensorDataToWriter(writer);
   WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
-                  writer, pools.Pool(), weights_path);
+                  writer);
 }
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
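Note (not part of the patch): the responsibility split after this change is that the caller creates the file by constructing BlobWriter, tensor data is written as it is added, and WriteSingleFile only appends the config/tokenizer/MatPtr blobs and finalizes the file. A sketch of the resulting save flow, mirroring Gemma::Save above; `weights` stands in for whatever object appends its tensor data to the writer:

    #include <cstdint>
    #include <vector>
    #include "gemma/model_store.h"
    #include "io/blob_store.h"

    void SaveSketch(const gcpp::ModelConfig& config,
                    const gcpp::GemmaTokenizer& tokenizer,
                    const gcpp::Path& path, hwy::ThreadPool& pool) {
      gcpp::BlobWriter writer(path, pool);   // creates the file, writes placeholder header
      std::vector<uint32_t> mat_ptrs;        // filled while adding tensor blobs
      // ... writer.Add(key, data, bytes) per tensor, recording its serialized MatPtr ...
      gcpp::WriteSingleFile(config, tokenizer, mat_ptrs, writer);  // ends with WriteAll()
    }
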
diff --git a/gemma/model_store.cc b/gemma/model_store.cc
index 0ffdedb..8f6c138 100644
--- a/gemma/model_store.cc
+++ b/gemma/model_store.cc
@@ -444,8 +444,7 @@ static void AddBlob(const char* name, const std::vector<uint32_t>& data,
 
 void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                      const std::vector<uint32_t>& serialized_mat_ptrs,
-                     BlobWriter& writer, hwy::ThreadPool& pool,
-                     const Path& path) {
+                     BlobWriter& writer) {
   HWY_ASSERT(config.model != Model::UNKNOWN);
   HWY_ASSERT(config.weight != Type::kUnknown);
   HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
@@ -459,7 +458,7 @@ void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
 
   AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
 
-  writer.WriteAll(pool, path);
+  writer.WriteAll();
 }
 
 }  // namespace gcpp
diff --git a/gemma/model_store.h b/gemma/model_store.h
index 0c0803a..42af343 100644
--- a/gemma/model_store.h
+++ b/gemma/model_store.h
@@ -105,8 +105,7 @@ class ModelStore {
 // produces a single BlobStore file holding everything required for inference.
 void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                      const std::vector<uint32_t>& serialized_mat_ptrs,
-                     BlobWriter& writer, hwy::ThreadPool& pool,
-                     const Path& path);
 
 }  // namespace gcpp
+                     BlobWriter& writer);
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_
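Note (not part of the patch): the io/blob_store.cc changes below introduce a "V2" layout in which the directory and the real header live at the end of the file, so blob payloads can be streamed before the directory contents are known. A self-contained sketch of how first-blob offsets fall out under the two layouts, assuming a 16-byte header, 16-byte keys, and 256-byte blob alignment (the 256 is implied by the existing test; the blob sizes are illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    constexpr uint64_t kHeaderBytes = 16, kKeyBytes = 16, kAlign = 256;
    constexpr uint64_t RoundUp(uint64_t x, uint64_t a) { return (x + a - 1) / a * a; }

    // V1 stores header + directory up front; V2 stores only a placeholder
    // header up front (directory + real header follow the payload).
    uint64_t FirstBlobOffsetV1(size_t num_blobs) {
      return RoundUp(kHeaderBytes + 2 * kKeyBytes * num_blobs, kAlign);
    }
    uint64_t FirstBlobOffsetV2() { return RoundUp(kHeaderBytes, kAlign); }

    int main() {
      // E.g. a 5-byte blob followed by a 6144-byte blob: offsets 256 and 512.
      std::vector<uint64_t> sizes = {5, 6144};
      uint64_t offset = FirstBlobOffsetV2();
      for (uint64_t s : sizes) {
        std::printf("V2 blob at %llu (%llu bytes)\n",
                    (unsigned long long)offset, (unsigned long long)s);
        offset = RoundUp(offset + s, kAlign);
      }
      // With two blobs, V1 also starts at 256 because the 80-byte
      // header+directory still fits in the first aligned block.
      std::printf("V1 first blob at %llu\n",
                  (unsigned long long)FirstBlobOffsetV1(sizes.size()));
      return 0;
    }
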
diff --git a/io/blob_store.cc b/io/blob_store.cc
index 6b33186..c3fbea1 100644
--- a/io/blob_store.cc
+++ b/io/blob_store.cc
@@ -86,16 +86,26 @@ static_assert(sizeof(Header) == 16);
 
 // A write I/O request, each serviced by one thread in a pool.
 struct BlobIO {
-  BlobIO(BlobRange range, void* data) : range(range), data(data) {}
+  BlobIO(BlobRange range, const void* data) : range(range), data(data) {}
 
   BlobRange range;
-  void* data;  // Read-only for writes.
+  const void* data;  // Read-only for writes.
 };
 
-// Little-endian on-disk representation: a fixed-size `Header`, then a padded
-// variable-length 'directory' of blob keys and their offset/sizes, then the
-// 'payload' of each blob's data with padding in between, followed by padding to
-// `kEndAlign`. Keys are unique, opaque 128-bit keys.
+// Little-endian on-disk representation:
+// For V1, the file is laid out as
+//   Header + Directory + PadToBlobAlign + Payload + PadToEndAlign.
+// For V2, the file is laid out as
+//   Header + PadToBlobAlign + Payload + PadToEndAlign + Directory + Header.
+// The Header at the beginning has num_blobs == 0, and the Header at the end
+// has the correct num_blobs.
+//
+// The payload is indexed by the directory's keys, offsets and sizes; keys are
+// unique, opaque 128-bit keys.
+//
+// The file format deliberately omits a version number because it is unchanging.
+// Additional data may be added only inside new blobs. Changes to the blob
+// contents or type should be handled by renaming keys.
 //
 // The file format deliberately omits a version number because it is unchanging.
 // Additional data may be added only inside new blobs. Changes to the blob
 // contents or type should be handled by renaming keys.
@@ -106,23 +116,22 @@ struct BlobIO {
 class BlobStore {
   static constexpr uint32_t kMagic = 0x0A534253;  // SBS\n
 
-  // Arbitrary upper limit to avoid allocating a huge vector.
-  static constexpr size_t kMaxBlobs = 64 * 1024;
+  // Upper limit to avoid allocating a huge vector.
+  static constexpr size_t kMaxBlobs = 16 * 1024;
 
-  // Returns the end of the directory, including padding, which is also the
-  // start of the first payload. `num_blobs` is `NumBlobs()` if the header is
+  // Returns the size of the padded header and directory, which is also the
+  // start of the first payload for V1. `num_blobs` is `NumBlobs()` if the
   // header is already available, otherwise the number of blobs to be written.
-  static size_t PaddedDirEnd(size_t num_blobs) {
+  static size_t PaddedHeaderAndDirBytes(size_t num_blobs) {
     HWY_ASSERT(num_blobs < kMaxBlobs);
     // Per blob, a key and offset/size.
     return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs);
   }
 
-  static uint64_t PaddedPayloadBytes(size_t num_blobs,
-                                     const hwy::Span<const uint8_t> blobs[]) {
+  static uint64_t PaddedPayloadBytes(const std::vector<size_t>& blob_sizes) {
     uint64_t total_payload_bytes = 0;
-    for (size_t i = 0; i < num_blobs; ++i) {
-      total_payload_bytes += RoundUpToAlign(blobs[i].size());
+    for (size_t blob_size : blob_sizes) {
+      total_payload_bytes += RoundUpToAlign(blob_size);
     }
     // Do not round up to `kEndAlign` because the padding also depends on the
     // directory size. Here we only count the payload.
@@ -136,6 +145,72 @@ class BlobStore {
     }
   }
 
+  bool ParseHeaderAndDirectoryV1(const File& file) {
+    is_file_v2_ = false;
+    // Read header from the beginning of the file.
+    if (!file.Read(0, sizeof(header_), &header_)) {
+      HWY_WARN("Failed to read BlobStore header.");
+      return false;
+    }
+
+    if (header_.magic != kMagic) {
+      HWY_WARN("BlobStore header magic %08x does not match %08x.",
+               header_.magic, kMagic);
+      return false;
+    }
+
+    if (header_.num_blobs == 0) {
+      // Should parse as V2.
+      return false;
+    }
+
+    if (header_.num_blobs > kMaxBlobs) {
+      HWY_WARN("Too many blobs, likely corrupt file.");
+      return false;
+    }
+
+    directory_.resize(header_.num_blobs * 2);
+    const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
+    // Read directory after the header.
+    if (!file.Read(sizeof(header_), directory_bytes, directory_.data())) {
+      HWY_WARN("Failed to read BlobStore directory.");
+      return false;
+    }
+    HWY_ASSERT(IsValid(file.FileSize()));
+    return true;
+  }
+
+  bool ParseHeaderAndDirectoryV2(const File& file) {
+    is_file_v2_ = true;
+    // Read header from the end of the file.
+    size_t offset = file.FileSize() - sizeof(header_);
+    if (!file.Read(offset, sizeof(header_), &header_)) {
+      HWY_WARN("Failed to read BlobStore header.");
+      return false;
+    }
+
+    if (header_.magic != kMagic) {
+      HWY_WARN("BlobStore header magic %08x does not match %08x.",
+               header_.magic, kMagic);
+      return false;
+    }
+
+    if (header_.num_blobs > kMaxBlobs) {
+      HWY_WARN("Too many blobs, likely corrupt file.");
+      return false;
+    }
+
+    directory_.resize(header_.num_blobs * 2);
+    const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
+    offset -= directory_bytes;
+    // Read directory immediately before the header.
+    if (!file.Read(offset, directory_bytes, directory_.data())) {
+      HWY_WARN("Failed to read BlobStore directory.");
+      return false;
+    }
+    HWY_ASSERT(IsValid(file.FileSize()));
+    return true;
+  }
+
  public:
  template <typename T>
  static T RoundUpToAlign(T size_or_offset) {
@@ -144,60 +219,46 @@
   // Reads header/directory from file.
   explicit BlobStore(const File& file) {
-    if (!file.Read(0, sizeof(header_), &header_)) {
-      HWY_WARN("Failed to read BlobStore header.");
+    if (ParseHeaderAndDirectoryV1(file)) {
       return;
     }
 
-    // Avoid allocating a huge vector.
-    if (header_.num_blobs >= kMaxBlobs) {
-      HWY_WARN("Too many blobs, likely corrupt file.");
-      return;
-    }
-
-    const size_t padded_dir_end = PaddedDirEnd(NumBlobs());
-    const size_t padded_dir_bytes = padded_dir_end - sizeof(header_);
-    HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
-    directory_.resize(padded_dir_bytes / kU128Bytes);
-    if (!file.Read(sizeof(header_), padded_dir_bytes, directory_.data())) {
-      HWY_WARN("Failed to read BlobStore directory.");
+    if (ParseHeaderAndDirectoryV2(file)) {
       return;
     }
+    HWY_ABORT("Failed to read BlobStore header or directory.");
   }
 
   // Initializes header/directory for writing to disk.
-  BlobStore(size_t num_blobs, const hwy::uint128_t keys[],
-            const hwy::Span<const uint8_t> blobs[]) {
+  BlobStore(const std::vector<hwy::uint128_t>& keys,
+            const std::vector<size_t>& blob_sizes) {
+    const size_t num_blobs = keys.size();
     HWY_ASSERT(num_blobs < kMaxBlobs);  // Ensures safe to cast to u32.
-    HWY_ASSERT(keys && blobs);
-    EnsureUnique(hwy::Span<const hwy::uint128_t>(keys, num_blobs));
-
-    uint64_t offset = PaddedDirEnd(num_blobs);
-    const size_t padded_dir_bytes =
-        static_cast<size_t>(offset) - sizeof(header_);
+    HWY_ASSERT(keys.size() == blob_sizes.size());
+    EnsureUnique(hwy::Span<const hwy::uint128_t>(keys.data(), num_blobs));
 
+    // Set header_.
     header_.magic = kMagic;
     header_.num_blobs = static_cast<uint32_t>(num_blobs);
-    header_.file_bytes = hwy::RoundUpTo(
-        offset + PaddedPayloadBytes(num_blobs, blobs), kEndAlign);
-    HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
-    directory_.resize(padded_dir_bytes / kU128Bytes);
-    hwy::CopyBytes(keys, directory_.data(), num_blobs * kU128Bytes);
+    const size_t size_before_blobs = BytesBeforeBlobsV2().size();
+    header_.file_bytes =
+        hwy::RoundUpTo(size_before_blobs + PaddedPayloadBytes(blob_sizes) +
+                           PaddedHeaderAndDirBytes(num_blobs),
+                       kEndAlign);
+
+    // Set the first num_blobs elements of directory_, which are the keys.
+    directory_.resize(num_blobs * 2);
+    hwy::CopyBytes(keys.data(), directory_.data(), num_blobs * kU128Bytes);
     EnsureUnique(Keys());
 
-    // `SetRange` below will fill `directory_[num_blobs, 2 * num_blobs)`.
-    hwy::ZeroBytes(directory_.data() + 2 * num_blobs,
-                   padded_dir_bytes - 2 * num_blobs * kU128Bytes);
-    // We already zero-initialized the directory padding;
-    // `BlobWriter::WriteAll` takes care of padding after each blob via an
-    // additional I/O.
+    // Set the second half of directory_, which is the offsets and sizes.
+    uint64_t offset = size_before_blobs;
     for (size_t i = 0; i < num_blobs; ++i) {
-      HWY_ASSERT(blobs[i].data() != nullptr);
-      SetRange(i, offset, blobs[i].size());
-      offset = RoundUpToAlign(offset + blobs[i].size());
+      SetRange(i, offset, blob_sizes[i]);
+      offset = RoundUpToAlign(offset + blob_sizes[i]);
     }
-    // When writing new files, we always pad to `kEndAlign`.
-    HWY_ASSERT(hwy::RoundUpTo(offset, kEndAlign) == header_.file_bytes);
+
+    HWY_ASSERT(IsValid(FileSize()));
   }
 
   // Must be checked by readers before other methods.
@@ -221,7 +282,10 @@
     }
 
     // Ensure blobs are back to back.
-    uint64_t expected_offset = PaddedDirEnd(NumBlobs());
+    const size_t size_before_blobs = BytesBeforeBlobs().size();
+    const size_t size_after_blobs = BytesAfterBlobs().size();
+
+    uint64_t expected_offset = size_before_blobs;
     for (size_t key_idx = 0; key_idx < NumBlobs(); ++key_idx) {
       uint64_t actual_offset;
       size_t bytes;
@@ -236,7 +300,7 @@
     }
     // Previously files were not padded to `kEndAlign`, so also allow that.
     if (expected_offset != header_.file_bytes &&
-        hwy::RoundUpTo(expected_offset, kEndAlign) != header_.file_bytes) {
+        expected_offset + size_after_blobs != header_.file_bytes) {
       HWY_WARN("Invalid BlobStore: end of blobs %zu but file size %zu.",
                static_cast<size_t>(expected_offset),
                static_cast<size_t>(header_.file_bytes));
@@ -246,20 +310,71 @@ class BlobStore {
     return true;  // all OK
   }
 
-  void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
-    const size_t key_idx = 0;  // not actually associated with a key/blob
-    writes.emplace_back(
-        BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
-        // members are const and BlobIO requires non-const pointers, and they
-        // are not modified by file writes.
-        const_cast<Header*>(&header_));
-    writes.emplace_back(
-        BlobRange{.offset = sizeof(header_),
-                  .bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
-                  .key_idx = key_idx},
-        const_cast<hwy::uint128_t*>(directory_.data()));
+  static std::vector<uint8_t> BytesBeforeBlobsV2() {
+    const Header kFakeHeaderV2 = {
+        .magic = kMagic,
+        .num_blobs = 0,
+        .file_bytes = kEndAlign,
+    };
+    std::vector<uint8_t> header(PaddedHeaderAndDirBytes(0));
+    hwy::CopyBytes(&kFakeHeaderV2, header.data(), sizeof(Header));
+    return header;
   }
 
+  std::vector<uint8_t> BytesBeforeBlobs() const {
+    if (is_file_v2_) {
+      return BytesBeforeBlobsV2();
+    } else {
+      const size_t padded_header_and_directory_size =
+          PaddedHeaderAndDirBytes(NumBlobs());
+      std::vector<uint8_t> header_and_directory(
+          padded_header_and_directory_size);
+
+      // Copy header_ at the beginning (offset 0).
+      hwy::CopyBytes(&header_, header_and_directory.data(), sizeof(header_));
+
+      // Copy directory_ immediately after the header_.
+      hwy::CopyBytes(directory_.data(),
+                     header_and_directory.data() + sizeof(header_),
+                     2 * kU128Bytes * NumBlobs());
+      return header_and_directory;
+    }
+  }
+
+  std::vector<uint8_t> BytesAfterBlobs() const {
+    // Gets the blob end.
+    uint64_t offset = 0;
+    size_t bytes = 0;
+    GetRange(NumBlobs() - 1, offset, bytes);
+    const uint64_t blob_end = RoundUpToAlign(offset + bytes);
+
+    // For V1, just return the file padding.
+    if (!is_file_v2_) {
+      return std::vector<uint8_t>(FileSize() - blob_end);
+    }
+
+    const size_t header_and_directory_with_file_padding_size =
+        FileSize() - blob_end;
+    std::vector<uint8_t> header_and_directory(
+        header_and_directory_with_file_padding_size);
+
+    const size_t header_size = sizeof(Header);
+    const size_t directory_size = 2 * kU128Bytes * NumBlobs();
+
+    // Copy header_ at the end.
+    offset = header_and_directory_with_file_padding_size - header_size;
+    hwy::CopyBytes(&header_, header_and_directory.data() + offset,
+                   header_size);
+
+    // Copy directory_ immediately before the header_.
+    offset -= directory_size;
+    hwy::CopyBytes(directory_.data(), header_and_directory.data() + offset,
+                   directory_size);
+
+    return header_and_directory;
+  }
+
+  size_t FileSize() const { return header_.file_bytes; }
+
   size_t NumBlobs() const { return static_cast<size_t>(header_.num_blobs); }
 
   // Not the entirety of `directory_`! The second half is offset/size.
@@ -291,6 +406,8 @@ class BlobStore {
     val.hi = bytes;
   }
 
+  bool is_file_v2_ = true;
+
   Header header_;
   std::vector<hwy::uint128_t> directory_;  // two per blob, see `SetRange`.
 
@@ -303,7 +420,6 @@ BlobReader::BlobReader(const Path& blob_path)
   if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
 
   BlobStore bs(*file_);
-  HWY_ASSERT(bs.IsValid(file_bytes_));  // IsValid already printed a warning
 
   keys_.reserve(bs.NumBlobs());
   for (const hwy::uint128_t key : bs.Keys()) {
@@ -324,8 +440,8 @@ BlobReader::BlobReader(const Path& blob_path)
 
 // Split into chunks for load-balancing even if blob sizes vary.
 static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
-                          uint8_t* data, std::vector<BlobIO>& writes) {
-  constexpr size_t kChunkBytes = 4 * 1024 * 1024;
+                          const uint8_t* data, std::vector<BlobIO>& writes) {
+  constexpr size_t kChunkBytes = 10 * 1024 * 1024;
   const uint64_t end = offset + bytes;
   // Split into whole chunks and possibly one remainder.
   if (end >= kChunkBytes) {
@@ -341,83 +457,42 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
         BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
         data);
   }
+
+  // Write padding if necessary.
+  static constexpr uint8_t kZeros[kBlobAlign] = {0};
+  const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
+  if (padding > 0) {
+    writes.emplace_back(
+        BlobRange{.offset = end, .bytes = padding, .key_idx = key_idx},
+        static_cast<const void*>(kZeros));
+  }
 }
 
-static void EnqueueWritesForBlobs(const BlobStore& bs,
-                                  const hwy::Span<const uint8_t> blobs[],
-                                  std::vector<uint8_t>& zeros,
-                                  std::vector<BlobIO>& writes) {
-  // All-zero buffer used to write padding to the file without copying the
-  // input blobs.
-  static constexpr uint8_t kZeros[kBlobAlign] = {0};
-
-  uint64_t file_end = 0;  // for padding
-  for (size_t key_idx = 0; key_idx < bs.NumBlobs(); ++key_idx) {
-    // We know the size, but `BlobStore` tells us the offset to write each blob.
-    uint64_t offset;
-    size_t bytes;
-    bs.GetRange(key_idx, offset, bytes);
-    HWY_ASSERT(offset != 0);
-    HWY_ASSERT(bytes == blobs[key_idx].size());
-    const uint64_t new_file_end = offset + bytes;
-    HWY_ASSERT(new_file_end >= file_end);  // blobs are ordered by offset
-    file_end = new_file_end;
-
-    EnqueueChunks(key_idx, offset, bytes,
-                  const_cast<uint8_t*>(blobs[key_idx].data()), writes);
-    const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
-    if (padding != 0) {
-      HWY_ASSERT(padding <= kBlobAlign);
-      writes.emplace_back(
-          BlobRange{
-              .offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
-          const_cast<uint8_t*>(kZeros));
-    }
-  }
-
-  const size_t padding = hwy::RoundUpTo(file_end, kEndAlign) - file_end;
-  if (padding != 0) {
-    // Bigger than `kZeros`, better to allocate than issue multiple I/Os. Must
-    // remain alive until the last I/O is done.
-    zeros.resize(padding);
-    writes.emplace_back(
-        BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
-        zeros.data());
-  }
+BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool)
+    : file_(OpenFileOrNull(filename, "w+")), pool_(pool) {
+  if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
+  // Write a fake header to the beginning of the file.
+  std::vector<uint8_t> bytes_before_blobs = BlobStore::BytesBeforeBlobsV2();
+  file_->Write(bytes_before_blobs.data(), bytes_before_blobs.size(), 0);
 }
 
 void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
   HWY_ASSERT(data != nullptr);
   HWY_ASSERT(bytes != 0);
   keys_.push_back(KeyFromString(key.c_str()));
-  blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
-}
-
-void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
-  const size_t num_blobs = keys_.size();
-  HWY_ASSERT(num_blobs != 0);
-  HWY_ASSERT(num_blobs == blobs_.size());
+  blob_sizes_.push_back(bytes);
 
   std::vector<BlobIO> writes;
-  writes.reserve(16384);
-
-  const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
-  bs.EnqueueWriteForHeaderAndDirectory(writes);
-
-  std::vector<uint8_t> zeros;
-  EnqueueWritesForBlobs(bs, blobs_.data(), zeros, writes);
-
-  // Create/replace existing file.
-  std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
-  if (!file) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
+  EnqueueChunks(keys_.size() - 1, file_->FileSize(), bytes,
+                static_cast<const uint8_t*>(data), writes);
 
   hwy::ThreadPool null_pool(0);
-  hwy::ThreadPool& pool_or_serial = file->IsAppendOnly() ? null_pool : pool;
+  hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_;
   pool_or_serial.Run(
-      0, writes.size(), [this, &file, &writes](uint64_t i, size_t /*thread*/) {
+      0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) {
        const BlobRange& range = writes[i].range;
-        if (!file->Write(writes[i].data, range.bytes, range.offset)) {
+        if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
          const std::string& key = StringFromKey(keys_[range.key_idx]);
          HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.",
                    key.c_str(), static_cast<size_t>(range.offset), range.bytes,
@@ -426,4 +501,13 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
       });
 }
 
+void BlobWriter::WriteAll() {
+  const BlobStore bs = BlobStore(keys_, blob_sizes_);
+
+  // Write the remaining bytes: padding + directory + header.
+  const auto bytes_after_blobs = bs.BytesAfterBlobs();
+  file_->Write(bytes_after_blobs.data(), bytes_after_blobs.size(),
+               file_->FileSize());
+}
+
 }  // namespace gcpp
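Note (not part of the patch): with the changes above, BlobWriter is now stateful across Add() calls, so the write path is open file, stream blobs, then finalize. A sketch using only the calls visible in this patch (file name, keys and sizes are illustrative):

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "hwy/contrib/thread_pool/thread_pool.h"
    #include "io/blob_store.h"

    void WriteTwoBlobs() {
      hwy::ThreadPool pool(4);
      // Opens the file and writes the placeholder (num_blobs == 0) header.
      gcpp::BlobWriter writer(gcpp::Path(std::string("/tmp/two_blobs.sbs")), pool);
      const std::vector<uint8_t> big(6144, 0x5A);
      writer.Add("small", "DATA", 5);             // payload written immediately
      writer.Add("big", big.data(), big.size());  // ditto, at the next aligned offset
      writer.WriteAll();  // appends end padding, directory, and the real header
    }
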
diff --git a/io/blob_store.h b/io/blob_store.h
index aa28210..f7103d7 100644
--- a/io/blob_store.h
+++ b/io/blob_store.h
@@ -116,6 +116,8 @@ class BlobReader {
 // does not make sense to call the methods concurrently.
 class BlobWriter {
  public:
+  explicit BlobWriter(const Path& filename, hwy::ThreadPool& pool);
+
   void Add(const std::string& key, const void* data, size_t bytes);
 
   // For `ModelStore`: this is the `key_idx` of the next blob to be added.
@@ -123,11 +125,13 @@ class BlobWriter {
 
   // Stores all blobs to disk in the given order with padding for alignment.
   // Aborts on error.
-  void WriteAll(hwy::ThreadPool& pool, const Path& filename);
+  void WriteAll();
 
  private:
+  std::unique_ptr<File> file_;
   std::vector<hwy::uint128_t> keys_;
-  std::vector<hwy::Span<const uint8_t>> blobs_;
+  std::vector<size_t> blob_sizes_;
+  hwy::ThreadPool& pool_;
 };
 
 }  // namespace gcpp
diff --git a/io/blob_store_test.cc b/io/blob_store_test.cc
index b763428..36ba27f 100644
--- a/io/blob_store_test.cc
+++ b/io/blob_store_test.cc
@@ -52,10 +52,10 @@ TEST(BlobStoreTest, TestReadWrite) {
   const std::string keyA("0123456789abcdef");  // max 16 characters
   const std::string keyB("q");
 
-  BlobWriter writer;
+  BlobWriter writer(path, pool);
   writer.Add(keyA, "DATA", 5);
   writer.Add(keyB, buffer.data(), sizeof(buffer));
-  writer.WriteAll(pool, path);
+  writer.WriteAll();
 
   HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
   std::fill(buffer.begin(), buffer.end(), 0);
@@ -69,12 +69,12 @@ TEST(BlobStoreTest, TestReadWrite) {
   const BlobRange* range = reader.Find(keyA);
   HWY_ASSERT(range);
   const uint64_t offsetA = range->offset;
-  HWY_ASSERT_EQ(offsetA, 256);  // kBlobAlign
+  HWY_ASSERT_EQ(offsetA, 256);
   HWY_ASSERT_EQ(range->bytes, 5);
   range = reader.Find(keyB);
   HWY_ASSERT(range);
   const uint64_t offsetB = range->offset;
-  HWY_ASSERT_EQ(offsetB, 2 * 256);
+  HWY_ASSERT_EQ(offsetB, offsetA + 256);
   HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
 
   HWY_ASSERT(
@@ -106,7 +106,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
   HWY_ASSERT(fd > 0);
   const Path path(path_str);
 
-  BlobWriter writer;
+  BlobWriter writer(path, pool);
   std::vector<std::string> keys;
   keys.reserve(num_blobs);
   std::vector<std::vector<uint8_t>> blobs;
@@ -126,7 +126,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
   }
   HWY_ASSERT(keys.size() == num_blobs);
   HWY_ASSERT(blobs.size() == num_blobs);
-  writer.WriteAll(pool, path);
+  writer.WriteAll();
 
   BlobReader reader(path);
   HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);
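Note (not part of the patch): reading is unchanged from the caller's point of view; the reader probes for a V1 header at the start of the file and falls back to the V2 header at the end. A read-back sketch matching the test expectations above (the concrete offsets follow from the 256-byte alignment and are illustrative):

    #include <cstdio>
    #include <string>

    #include "io/blob_store.h"

    void ReadBackTwoBlobs() {
      gcpp::BlobReader reader(gcpp::Path(std::string("/tmp/two_blobs.sbs")));
      const gcpp::BlobRange* small = reader.Find("small");
      const gcpp::BlobRange* big = reader.Find("big");
      if (small && big) {
        // With 256-byte blob alignment, the first payload starts at 256 and the
        // second at 256 + RoundUp(5, 256) = 512, for both V1 and V2 files.
        std::printf("small: offset=%llu bytes=%llu\n",
                    (unsigned long long)small->offset,
                    (unsigned long long)small->bytes);
        std::printf("big:   offset=%llu bytes=%llu\n",
                    (unsigned long long)big->offset,
                    (unsigned long long)big->bytes);
      }
    }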