Write SBS progressively.

(1) Directly write to file in BlobWriter::Add and destruct the MatOwner to release the rams.

(2) Write a fake header to indicate this is V2, and write correct header and directory at the end of the file.

(3) Tested on loading sbs written the old way, and new way, both worked.

PiperOrigin-RevId: 789306837
This commit is contained in:
Charles Zhao 2025-07-31 06:05:02 -07:00 committed by Copybara-Service
parent 8715eda512
commit 50ee1a3e92
10 changed files with 255 additions and 169 deletions

View File

@ -75,8 +75,8 @@ class SbsWriterImpl : public ISbsWriter {
}
mat.AppendTo(serialized_mat_ptrs_);
mat_owners_.push_back(MatOwner());
mat_owners_.back().AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
MatOwner mat_owner;
mat_owner.AllocateFor(mat, ctx_.allocator, MatPadding::kPacked);
// Handle gemma_export_test's MockArray. Write blobs so that the test
// succeeds, but we only have 10 floats, not the full tensor.
@ -97,7 +97,9 @@ class SbsWriterImpl : public ISbsWriter {
}
public:
SbsWriterImpl() : ctx_(ThreadingArgs()) {}
SbsWriterImpl(const std::string& sbs_path)
: ctx_(ThreadingArgs()),
writer_(gcpp::Path(sbs_path), ctx_.pools.Pool()) {}
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override {
@ -120,23 +122,23 @@ class SbsWriterImpl : public ISbsWriter {
}
}
void Write(const ModelConfig& config, const std::string& tokenizer_path,
const std::string& path) override {
void Write(const ModelConfig& config,
const std::string& tokenizer_path) override {
const GemmaTokenizer tokenizer(
tokenizer_path.empty() ? kMockTokenizer
: ReadFileToString(Path(tokenizer_path)));
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_,
ctx_.pools.Pool(), gcpp::Path(path));
WriteSingleFile(config, tokenizer, serialized_mat_ptrs_, writer_);
}
ThreadingContext ctx_;
std::vector<MatOwner> mat_owners_;
CompressWorkingSet working_set_;
BlobWriter writer_;
std::vector<uint32_t> serialized_mat_ptrs_;
};
ISbsWriter* NewSbsWriter() { return new SbsWriterImpl(); }
ISbsWriter* NewSbsWriter(const std::string& sbs_path) {
return new SbsWriterImpl(sbs_path);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
@ -147,7 +149,8 @@ namespace gcpp {
HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsWriter::SbsWriter(const std::string& path)
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}
SbsReader::SbsReader(const std::string& path)
: reader_(Path(path)), model_(reader_) {}

View File

@ -44,24 +44,22 @@ class ISbsWriter {
const TensorInfo& tensor_info) = 0;
virtual void Write(const ModelConfig& config,
const std::string& tokenizer_path,
const std::string& path) = 0;
const std::string& tokenizer_path) = 0;
};
// Non-virtual class used by pybind that calls the interface's virtual methods.
// This avoids having to register the derived types with pybind.
class SbsWriter {
public:
SbsWriter();
explicit SbsWriter(const std::string& sbs_path);
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) {
impl_->Insert(name, weights, type, tensor_info);
}
void Write(const ModelConfig& config, const std::string& tokenizer_path,
const std::string& path) {
impl_->Write(config, tokenizer_path, path);
void Write(const ModelConfig& config, const std::string& tokenizer_path) {
impl_->Write(config, tokenizer_path);
}
private:

View File

@ -44,10 +44,9 @@ static void CallWithF32Span(SbsWriter& writer, const char* name,
PYBIND11_MODULE(compression, m) {
class_<SbsWriter>(m, "SbsWriter")
.def(init<>())
.def(init<std::string>())
.def("insert", CallWithF32Span<&SbsWriter::Insert>)
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"),
arg("path"));
.def("write", &SbsWriter::Write, arg("config"), arg("tokenizer_path"));
class_<MatPtr>(m, "MatPtr")
// No init, only created within C++.

View File

@ -30,7 +30,8 @@ class CompressionTest(absltest.TestCase):
info_192.axes = [0]
info_192.shape = [192]
writer = compression.SbsWriter()
temp_file = self.create_tempfile("test.sbs")
writer = compression.SbsWriter(temp_file.full_path)
writer.insert(
"tensor0",
# Large enough to require scaling.
@ -95,8 +96,7 @@ class CompressionTest(absltest.TestCase):
configs.PromptWrapping.GEMMA_IT,
)
tokenizer_path = "" # no tokenizer required for testing
temp_file = self.create_tempfile("test.sbs")
writer.write(config, tokenizer_path, temp_file.full_path)
writer.write(config, tokenizer_path)
print("Ignore next two warnings; test does not enable model deduction.")
reader = compression.SbsReader(temp_file.full_path)

View File

@ -618,11 +618,11 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, NestedPools& pools) const {
BlobWriter writer;
BlobWriter writer(weights_path, pools.Pool());
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
writer, pools.Pool(), weights_path);
writer);
}
void Gemma::Generate(const RuntimeConfig& runtime_config,

View File

@ -444,8 +444,7 @@ static void AddBlob(const char* name, const std::vector<uint32_t>& data,
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer, hwy::ThreadPool& pool,
const Path& path) {
BlobWriter& writer) {
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.weight != Type::kUnknown);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
@ -459,7 +458,7 @@ void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
AddBlob(kMatPtrsName, serialized_mat_ptrs, writer);
writer.WriteAll(pool, path);
writer.WriteAll();
}
} // namespace gcpp

View File

@ -105,8 +105,7 @@ class ModelStore {
// produces a single BlobStore file holding everything required for inference.
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer, hwy::ThreadPool& pool,
const Path& path);
BlobWriter& writer);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_MODEL_STORE_H_

View File

@ -86,16 +86,26 @@ static_assert(sizeof(Header) == 16);
// A write I/O request, each serviced by one thread in a pool.
struct BlobIO {
BlobIO(BlobRange range, void* data) : range(range), data(data) {}
BlobIO(BlobRange range, const void* data) : range(range), data(data) {}
BlobRange range;
void* data; // Read-only for writes.
const void* data; // Read-only for writes.
};
// Little-endian on-disk representation: a fixed-size `Header`, then a padded
// variable-length 'directory' of blob keys and their offset/sizes, then the
// 'payload' of each blob's data with padding in between, followed by padding to
// `kEndAlign`. Keys are unique, opaque 128-bit keys.
// Little-endian on-disk representation:
// For V1: the file is represented as
// Header + Directory + PadToBlobAlign + Payload + PayToEndAlign.
// For V2: the file is represented as
// Header + PadToBlobAlign + Payload + PadToEndAlign + Directory + Header
// The Header at the beginning has num_blobs == 0; and the Header at the end has
// the correct num_blobs.
//
// Actual payload is indexed by the directory with keys, offset and bytes; keys
// are unique, opaque 128-bit keys.
//
// The file format deliberately omits a version number because it is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob
// contents or type should be handled by renaming keys.
//
// The file format deliberately omits a version number because it is unchanging.
// Additional data may be added only inside new blobs. Changes to the blob
@ -106,23 +116,22 @@ struct BlobIO {
class BlobStore {
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
// Arbitrary upper limit to avoid allocating a huge vector.
static constexpr size_t kMaxBlobs = 64 * 1024;
// Upper limit to avoid allocating a huge vector.
static constexpr size_t kMaxBlobs = 16 * 1024;
// Returns the end of the directory, including padding, which is also the
// start of the first payload. `num_blobs` is `NumBlobs()` if the header is
// Returns the size of padded header and directory, which is also the start of
// the first payload for V1. `num_blobs` is `NumBlobs()` if the header is
// already available, otherwise the number of blobs to be written.
static size_t PaddedDirEnd(size_t num_blobs) {
static size_t PaddedHeaderAndDirBytes(size_t num_blobs) {
HWY_ASSERT(num_blobs < kMaxBlobs);
// Per blob, a key and offset/size.
return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs);
}
static uint64_t PaddedPayloadBytes(size_t num_blobs,
const hwy::Span<const uint8_t> blobs[]) {
static uint64_t PaddedPayloadBytes(const std::vector<size_t>& blob_sizes) {
uint64_t total_payload_bytes = 0;
for (size_t i = 0; i < num_blobs; ++i) {
total_payload_bytes += RoundUpToAlign(blobs[i].size());
for (size_t blob_size : blob_sizes) {
total_payload_bytes += RoundUpToAlign(blob_size);
}
// Do not round up to `kEndAlign` because the padding also depends on the
// directory size. Here we only count the payload.
@ -136,6 +145,72 @@ class BlobStore {
}
}
bool ParseHeaderAndDirectoryV1(const File& file) {
is_file_v2_ = false;
// Read header from the beginning of the file.
if (!file.Read(0, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
return false;
}
if (header_.magic != kMagic) {
HWY_WARN("BlobStore header magic %08x does not match %08x.",
header_.magic, kMagic);
return false;
}
if (header_.num_blobs == 0) {
// Should parse as V2.
return false;
}
if (header_.num_blobs > kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return false;
}
directory_.resize(header_.num_blobs * 2);
const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
// Read directory after the header.
if (!file.Read(sizeof(header_), directory_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return false;
}
HWY_ASSERT(IsValid(file.FileSize()));
return true;
}
bool ParseHeaderAndDirectoryV2(const File& file) {
is_file_v2_ = true;
// Read header from the end of the file.
size_t offset = file.FileSize() - sizeof(header_);
if (!file.Read(offset, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
return false;
}
if (header_.magic != kMagic) {
HWY_WARN("BlobStore header magic %08x does not match %08x.",
header_.magic, kMagic);
return false;
}
if (header_.num_blobs > kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return false;
}
directory_.resize(header_.num_blobs * 2);
const auto directory_bytes = 2 * kU128Bytes * header_.num_blobs;
offset -= directory_bytes;
// Read directory immediately before the header.
if (!file.Read(offset, directory_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
return false;
}
HWY_ASSERT(IsValid(file.FileSize()));
return true;
}
public:
template <typename T>
static T RoundUpToAlign(T size_or_offset) {
@ -144,60 +219,46 @@ class BlobStore {
// Reads header/directory from file.
explicit BlobStore(const File& file) {
if (!file.Read(0, sizeof(header_), &header_)) {
HWY_WARN("Failed to read BlobStore header.");
if (ParseHeaderAndDirectoryV1(file)) {
return;
}
// Avoid allocating a huge vector.
if (header_.num_blobs >= kMaxBlobs) {
HWY_WARN("Too many blobs, likely corrupt file.");
return;
}
const size_t padded_dir_end = PaddedDirEnd(NumBlobs());
const size_t padded_dir_bytes = padded_dir_end - sizeof(header_);
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
directory_.resize(padded_dir_bytes / kU128Bytes);
if (!file.Read(sizeof(header_), padded_dir_bytes, directory_.data())) {
HWY_WARN("Failed to read BlobStore directory.");
if (ParseHeaderAndDirectoryV2(file)) {
return;
}
HWY_ABORT("Failed to read BlobStore header or directory.");
}
// Initializes header/directory for writing to disk.
BlobStore(size_t num_blobs, const hwy::uint128_t keys[],
const hwy::Span<const uint8_t> blobs[]) {
BlobStore(const std::vector<hwy::uint128_t>& keys,
const std::vector<size_t>& blob_sizes) {
const size_t num_blobs = keys.size();
HWY_ASSERT(num_blobs < kMaxBlobs); // Ensures safe to cast to u32.
HWY_ASSERT(keys && blobs);
EnsureUnique(hwy::Span<const hwy::uint128_t>(keys, num_blobs));
uint64_t offset = PaddedDirEnd(num_blobs);
const size_t padded_dir_bytes =
static_cast<size_t>(offset) - sizeof(header_);
HWY_ASSERT(keys.size() == blob_sizes.size());
EnsureUnique(hwy::Span<const hwy::uint128_t>(keys.data(), num_blobs));
// Set header_.
header_.magic = kMagic;
header_.num_blobs = static_cast<uint32_t>(num_blobs);
header_.file_bytes = hwy::RoundUpTo(
offset + PaddedPayloadBytes(num_blobs, blobs), kEndAlign);
HWY_ASSERT(padded_dir_bytes % kU128Bytes == 0);
directory_.resize(padded_dir_bytes / kU128Bytes);
hwy::CopyBytes(keys, directory_.data(), num_blobs * kU128Bytes);
const size_t size_before_blobs = BytesBeforeBlobsV2().size();
header_.file_bytes =
hwy::RoundUpTo(size_before_blobs + PaddedPayloadBytes(blob_sizes) +
PaddedHeaderAndDirBytes(num_blobs),
kEndAlign);
// Set first num_blobs elements of directory_ which are the keys.
directory_.resize(num_blobs * 2);
hwy::CopyBytes(keys.data(), directory_.data(), num_blobs * kU128Bytes);
EnsureUnique(Keys());
// `SetRange` below will fill `directory_[num_blobs, 2 * num_blobs)`.
hwy::ZeroBytes(directory_.data() + 2 * num_blobs,
padded_dir_bytes - 2 * num_blobs * kU128Bytes);
// We already zero-initialized the directory padding;
// `BlobWriter::WriteAll` takes care of padding after each blob via an
// additional I/O.
// Set the second half of directory_ which is the offsets and sizes.
uint64_t offset = size_before_blobs;
for (size_t i = 0; i < num_blobs; ++i) {
HWY_ASSERT(blobs[i].data() != nullptr);
SetRange(i, offset, blobs[i].size());
offset = RoundUpToAlign(offset + blobs[i].size());
SetRange(i, offset, blob_sizes[i]);
offset = RoundUpToAlign(offset + blob_sizes[i]);
}
// When writing new files, we always pad to `kEndAlign`.
HWY_ASSERT(hwy::RoundUpTo(offset, kEndAlign) == header_.file_bytes);
HWY_ASSERT(IsValid(FileSize()));
}
// Must be checked by readers before other methods.
@ -221,7 +282,10 @@ class BlobStore {
}
// Ensure blobs are back to back.
uint64_t expected_offset = PaddedDirEnd(NumBlobs());
const size_t size_before_blobs = BytesBeforeBlobs().size();
const size_t size_after_blobs = BytesAfterBlobs().size();
uint64_t expected_offset = size_before_blobs;
for (size_t key_idx = 0; key_idx < NumBlobs(); ++key_idx) {
uint64_t actual_offset;
size_t bytes;
@ -236,7 +300,7 @@ class BlobStore {
}
// Previously files were not padded to `kEndAlign`, so also allow that.
if (expected_offset != header_.file_bytes &&
hwy::RoundUpTo(expected_offset, kEndAlign) != header_.file_bytes) {
expected_offset + size_after_blobs != header_.file_bytes) {
HWY_WARN("Invalid BlobStore: end of blobs %zu but file size %zu.",
static_cast<size_t>(expected_offset),
static_cast<size_t>(header_.file_bytes));
@ -246,20 +310,71 @@ class BlobStore {
return true; // all OK
}
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO>& writes) const {
const size_t key_idx = 0; // not actually associated with a key/blob
writes.emplace_back(
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
// members are const and BlobIO requires non-const pointers, and they
// are not modified by file writes.
const_cast<Header*>(&header_));
writes.emplace_back(
BlobRange{.offset = sizeof(header_),
.bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
.key_idx = key_idx},
const_cast<hwy::uint128_t*>(directory_.data()));
static std::vector<uint8_t> BytesBeforeBlobsV2() {
const Header kFakeHeaderV2 = {
.magic = kMagic,
.num_blobs = 0,
.file_bytes = kEndAlign,
};
std::vector<uint8_t> header(PaddedHeaderAndDirBytes(0));
hwy::CopyBytes(&kFakeHeaderV2, header.data(), sizeof(Header));
return header;
}
std::vector<uint8_t> BytesBeforeBlobs() const {
if (is_file_v2_) {
return BytesBeforeBlobsV2();
} else {
const size_t padded_header_and_directory_size =
PaddedHeaderAndDirBytes(NumBlobs());
std::vector<uint8_t> header_and_directory(
padded_header_and_directory_size);
// Copy header_ at the beginning (offset 0)
hwy::CopyBytes(&header_, header_and_directory.data(), sizeof(header_));
// Copy directory_ immediately after the header_
hwy::CopyBytes(directory_.data(),
header_and_directory.data() + sizeof(header_),
2 * kU128Bytes * NumBlobs());
return header_and_directory;
}
}
std::vector<uint8_t> BytesAfterBlobs() const {
// Gets blob end.
uint64_t offset = 0;
size_t bytes = 0;
GetRange(NumBlobs() - 1, offset, bytes);
const uint64_t blob_end = RoundUpToAlign(offset + bytes);
// For V1, just return the file paddings.
if (!is_file_v2_) {
return std::vector<uint8_t>(FileSize() - blob_end);
}
const size_t header_and_directory_with_file_padding_size =
FileSize() - blob_end;
std::vector<uint8_t> header_and_directory(
header_and_directory_with_file_padding_size);
const size_t header_size = sizeof(Header);
const size_t directory_size = 2 * kU128Bytes * NumBlobs();
// Copy header_ at the end.
offset = header_and_directory_with_file_padding_size - header_size;
hwy::CopyBytes(&header_, header_and_directory.data() + offset, header_size);
// Copy directory_ immediately before the header_.
offset -= directory_size;
hwy::CopyBytes(directory_.data(), header_and_directory.data() + offset,
directory_size);
return header_and_directory;
}
size_t FileSize() const { return header_.file_bytes; }
size_t NumBlobs() const { return static_cast<size_t>(header_.num_blobs); }
// Not the entirety of `directory_`! The second half is offset/size.
@ -291,6 +406,8 @@ class BlobStore {
val.hi = bytes;
}
bool is_file_v2_ = true;
Header header_;
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
@ -303,7 +420,6 @@ BlobReader::BlobReader(const Path& blob_path)
if (file_bytes_ == 0) HWY_ABORT("Zero-sized file %s", blob_path.path.c_str());
BlobStore bs(*file_);
HWY_ASSERT(bs.IsValid(file_bytes_)); // IsValid already printed a warning
keys_.reserve(bs.NumBlobs());
for (const hwy::uint128_t key : bs.Keys()) {
@ -324,8 +440,8 @@ BlobReader::BlobReader(const Path& blob_path)
// Split into chunks for load-balancing even if blob sizes vary.
static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
uint8_t* data, std::vector<BlobIO>& writes) {
constexpr size_t kChunkBytes = 4 * 1024 * 1024;
const uint8_t* data, std::vector<BlobIO>& writes) {
constexpr size_t kChunkBytes = 10 * 1024 * 1024;
const uint64_t end = offset + bytes;
// Split into whole chunks and possibly one remainder.
if (end >= kChunkBytes) {
@ -341,83 +457,42 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
data);
}
// Write a padding if necessary.
static constexpr uint8_t kZeros[kBlobAlign] = {0};
const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
if (padding > 0) {
writes.emplace_back(
BlobRange{.offset = end, .bytes = padding, .key_idx = key_idx},
static_cast<const uint8_t*>(kZeros));
}
}
static void EnqueueWritesForBlobs(const BlobStore& bs,
const hwy::Span<const uint8_t> blobs[],
std::vector<uint8_t>& zeros,
std::vector<BlobIO>& writes) {
// All-zero buffer used to write padding to the file without copying the
// input blobs.
static constexpr uint8_t kZeros[kBlobAlign] = {0};
uint64_t file_end = 0; // for padding
for (size_t key_idx = 0; key_idx < bs.NumBlobs(); ++key_idx) {
// We know the size, but `BlobStore` tells us the offset to write each blob.
uint64_t offset;
size_t bytes;
bs.GetRange(key_idx, offset, bytes);
HWY_ASSERT(offset != 0);
HWY_ASSERT(bytes == blobs[key_idx].size());
const uint64_t new_file_end = offset + bytes;
HWY_ASSERT(new_file_end >= file_end); // blobs are ordered by offset
file_end = new_file_end;
EnqueueChunks(key_idx, offset, bytes,
const_cast<uint8_t*>(blobs[key_idx].data()), writes);
const size_t padding = BlobStore::RoundUpToAlign(bytes) - bytes;
if (padding != 0) {
HWY_ASSERT(padding <= kBlobAlign);
writes.emplace_back(
BlobRange{
.offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
const_cast<uint8_t*>(kZeros));
}
}
const size_t padding = hwy::RoundUpTo(file_end, kEndAlign) - file_end;
if (padding != 0) {
// Bigger than `kZeros`, better to allocate than issue multiple I/Os. Must
// remain alive until the last I/O is done.
zeros.resize(padding);
writes.emplace_back(
BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
zeros.data());
}
BlobWriter::BlobWriter(const Path& filename, hwy::ThreadPool& pool)
: file_(OpenFileOrNull(filename, "w+")), pool_(pool) {
if (!file_) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
// Write a fake header to the beginning of the file.
std::vector<uint8_t> bytes_before_blobs = BlobStore::BytesBeforeBlobsV2();
file_->Write(bytes_before_blobs.data(), bytes_before_blobs.size(), 0);
}
void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
HWY_ASSERT(data != nullptr);
HWY_ASSERT(bytes != 0);
keys_.push_back(KeyFromString(key.c_str()));
blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
}
void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
const size_t num_blobs = keys_.size();
HWY_ASSERT(num_blobs != 0);
HWY_ASSERT(num_blobs == blobs_.size());
blob_sizes_.push_back(bytes);
std::vector<BlobIO> writes;
writes.reserve(16384);
const BlobStore bs(num_blobs, keys_.data(), blobs_.data());
bs.EnqueueWriteForHeaderAndDirectory(writes);
std::vector<uint8_t> zeros;
EnqueueWritesForBlobs(bs, blobs_.data(), zeros, writes);
// Create/replace existing file.
std::unique_ptr<File> file = OpenFileOrNull(filename, "w+");
if (!file) HWY_ABORT("Failed to open for writing %s", filename.path.c_str());
EnqueueChunks(keys_.size() - 1, file_->FileSize(), bytes,
static_cast<const uint8_t*>(data), writes);
hwy::ThreadPool null_pool(0);
hwy::ThreadPool& pool_or_serial = file->IsAppendOnly() ? null_pool : pool;
hwy::ThreadPool& pool_or_serial = file_->IsAppendOnly() ? null_pool : pool_;
pool_or_serial.Run(
0, writes.size(), [this, &file, &writes](uint64_t i, size_t /*thread*/) {
0, writes.size(), [this, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange& range = writes[i].range;
if (!file->Write(writes[i].data, range.bytes, range.offset)) {
if (!file_->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]);
HWY_ABORT("Write failed for %s from %zu, %zu bytes to %p.",
key.c_str(), static_cast<size_t>(range.offset), range.bytes,
@ -426,4 +501,13 @@ void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
});
}
void BlobWriter::WriteAll() {
const BlobStore bs = BlobStore(keys_, blob_sizes_);
// Write the rest of the bytes, which contains: paddings + directory + header.
const auto bytes_after_blobs = bs.BytesAfterBlobs();
file_->Write(bytes_after_blobs.data(), bytes_after_blobs.size(),
file_->FileSize());
}
} // namespace gcpp

View File

@ -116,6 +116,8 @@ class BlobReader {
// does not make sense to call the methods concurrently.
class BlobWriter {
public:
explicit BlobWriter(const Path& filename, hwy::ThreadPool& pool);
void Add(const std::string& key, const void* data, size_t bytes);
// For `ModelStore`: this is the `key_idx` of the next blob to be added.
@ -123,11 +125,13 @@ class BlobWriter {
// Stores all blobs to disk in the given order with padding for alignment.
// Aborts on error.
void WriteAll(hwy::ThreadPool& pool, const Path& filename);
void WriteAll();
private:
std::unique_ptr<File> file_;
std::vector<hwy::uint128_t> keys_;
std::vector<hwy::Span<const uint8_t>> blobs_;
std::vector<size_t> blob_sizes_;
hwy::ThreadPool& pool_;
};
} // namespace gcpp

View File

@ -52,10 +52,10 @@ TEST(BlobStoreTest, TestReadWrite) {
const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q");
BlobWriter writer;
BlobWriter writer(path, pool);
writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.WriteAll(pool, path);
writer.WriteAll();
HWY_ASSERT_ARRAY_EQ(kOriginalData.data(), buffer.data(), buffer.size());
std::fill(buffer.begin(), buffer.end(), 0);
@ -69,12 +69,12 @@ TEST(BlobStoreTest, TestReadWrite) {
const BlobRange* range = reader.Find(keyA);
HWY_ASSERT(range);
const uint64_t offsetA = range->offset;
HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign
HWY_ASSERT_EQ(offsetA, 256);
HWY_ASSERT_EQ(range->bytes, 5);
range = reader.Find(keyB);
HWY_ASSERT(range);
const uint64_t offsetB = range->offset;
HWY_ASSERT_EQ(offsetB, 2 * 256);
HWY_ASSERT_EQ(offsetB, offsetA + 256);
HWY_ASSERT_EQ(range->bytes, sizeof(buffer));
HWY_ASSERT(
@ -106,7 +106,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0);
const Path path(path_str);
BlobWriter writer;
BlobWriter writer(path, pool);
std::vector<std::string> keys;
keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs;
@ -126,7 +126,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
}
HWY_ASSERT(keys.size() == num_blobs);
HWY_ASSERT(blobs.size() == num_blobs);
writer.WriteAll(pool, path);
writer.WriteAll();
BlobReader reader(path);
HWY_ASSERT_EQ(reader.Keys().size(), num_blobs);