mirror of https://github.com/google/gemma.cpp.git
Serialization for class members for use with ModelConfig
PiperOrigin-RevId: 689720027
This commit is contained in:
parent
efff64605a
commit
52af531820
|
|
@ -45,6 +45,8 @@ set(SOURCES
|
||||||
compression/compress.cc
|
compression/compress.cc
|
||||||
compression/compress.h
|
compression/compress.h
|
||||||
compression/compress-inl.h
|
compression/compress-inl.h
|
||||||
|
compression/fields.cc
|
||||||
|
compression/fields.h
|
||||||
compression/io_win.cc
|
compression/io_win.cc
|
||||||
compression/io.cc
|
compression/io.cc
|
||||||
compression/io.h
|
compression/io.h
|
||||||
|
|
@ -150,6 +152,7 @@ set(GEMMA_TEST_FILES
|
||||||
compression/blob_store_test.cc
|
compression/blob_store_test.cc
|
||||||
compression/compress_test.cc
|
compression/compress_test.cc
|
||||||
compression/distortion_test.cc
|
compression/distortion_test.cc
|
||||||
|
compression/fields_test.cc
|
||||||
compression/nuq_test.cc
|
compression/nuq_test.cc
|
||||||
compression/sfp_test.cc
|
compression/sfp_test.cc
|
||||||
evals/gemma_test.cc
|
evals/gemma_test.cc
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,26 @@ cc_library(
|
||||||
] + FILE_DEPS,
|
] + FILE_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "fields",
|
||||||
|
srcs = ["fields.cc"],
|
||||||
|
hdrs = ["fields.h"],
|
||||||
|
deps = [
|
||||||
|
"@highway//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "fields_test",
|
||||||
|
srcs = ["fields_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":fields",
|
||||||
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"@highway//:hwy",
|
||||||
|
"@highway//:hwy_test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "blob_store",
|
name = "blob_store",
|
||||||
srcs = ["blob_store.cc"],
|
srcs = ["blob_store.cc"],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,320 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "compression/fields.h"
|
||||||
|
|
||||||
|
#include <stdarg.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
IFieldsVisitor::~IFieldsVisitor() = default;
|
||||||
|
|
||||||
|
void IFieldsVisitor::NotifyInvalid(const char* fmt, ...) {
|
||||||
|
va_list args;
|
||||||
|
va_start(args, fmt);
|
||||||
|
vfprintf(stderr, fmt, args);
|
||||||
|
va_end(args);
|
||||||
|
|
||||||
|
any_invalid_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
class VisitorBase : public IFieldsVisitor {
|
||||||
|
public:
|
||||||
|
VisitorBase() = default;
|
||||||
|
~VisitorBase() override = default;
|
||||||
|
|
||||||
|
// This is the only call site of IFields::VisitFields.
|
||||||
|
void operator()(IFields& fields) override { fields.VisitFields(*this); }
|
||||||
|
|
||||||
|
protected: // Functions shared between ReadVisitor and WriteVisitor:
|
||||||
|
void CheckF32(float value) {
|
||||||
|
if (HWY_UNLIKELY(hwy::ScalarIsInf(value) || hwy::ScalarIsNaN(value))) {
|
||||||
|
NotifyInvalid("Invalid float %g\n", value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return bool to avoid having to check AnyInvalid() after calling.
|
||||||
|
bool CheckStringLength(uint32_t num_u32) {
|
||||||
|
// Disallow long strings for safety, and to prevent them being used for
|
||||||
|
// arbitrary data (we also require them to be ASCII).
|
||||||
|
if (HWY_UNLIKELY(num_u32 > 64)) {
|
||||||
|
NotifyInvalid("String num_u32=%u too large\n", num_u32);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CheckStringU32(uint32_t u32, uint32_t i, uint32_t num_u32) {
|
||||||
|
// Although strings are zero-padded to u32, an entire u32 should not be
|
||||||
|
// zero, and upper bits should not be set (ASCII-only).
|
||||||
|
if (HWY_UNLIKELY(u32 == 0 || (u32 & 0x80808080))) {
|
||||||
|
NotifyInvalid("Invalid characters %x at %u of %u\n", u32, i, num_u32);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class PrintVisitor : public VisitorBase {
|
||||||
|
public:
|
||||||
|
void operator()(uint32_t& value) override {
|
||||||
|
fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(float& value) override {
|
||||||
|
fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(std::string& value) override {
|
||||||
|
fprintf(stderr, "%sStr %s\n", indent_.c_str(), value.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(IFields& fields) override {
|
||||||
|
fprintf(stderr, "%s%s\n", indent_.c_str(), fields.Name());
|
||||||
|
indent_ += " ";
|
||||||
|
|
||||||
|
VisitorBase::operator()(fields);
|
||||||
|
|
||||||
|
HWY_ASSERT(!indent_.empty());
|
||||||
|
indent_.resize(indent_.size() - 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string indent_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ReadVisitor : public VisitorBase {
|
||||||
|
public:
|
||||||
|
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
|
||||||
|
: span_(span), result_(pos) {}
|
||||||
|
~ReadVisitor() {
|
||||||
|
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
|
||||||
|
}
|
||||||
|
|
||||||
|
// All data is read through this overload.
|
||||||
|
void operator()(uint32_t& value) override {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
value = span_[result_.pos++];
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(float& value) override {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
uint32_t u32 = hwy::BitCastScalar<uint32_t>(value);
|
||||||
|
operator()(u32);
|
||||||
|
value = hwy::BitCastScalar<float>(u32);
|
||||||
|
CheckF32(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(std::string& value) override {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
uint32_t num_u32; // not including itself because this..
|
||||||
|
operator()(num_u32); // increments result_.pos for the num_u32 field
|
||||||
|
if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
|
||||||
|
|
||||||
|
// Ensure we have that much data.
|
||||||
|
if (HWY_UNLIKELY(result_.pos + num_u32 > end_.back())) {
|
||||||
|
NotifyInvalid("Invalid string: pos %zu + num_u32 %u > end %zu\n",
|
||||||
|
result_.pos, num_u32, span_.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr size_t k4 = sizeof(uint32_t);
|
||||||
|
value.resize(num_u32 * k4);
|
||||||
|
for (uint32_t i = 0; i < num_u32; ++i) {
|
||||||
|
uint32_t u32;
|
||||||
|
operator()(u32);
|
||||||
|
(void)CheckStringU32(u32, i, num_u32);
|
||||||
|
hwy::CopyBytes(&u32, value.data() + i * k4, k4);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trim 0..3 trailing nulls.
|
||||||
|
const size_t pos = value.find_last_not_of('\0');
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
value.resize(pos + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(IFields& fields) override {
|
||||||
|
// Our SkipField requires end_ to be set before reading num_u32, which
|
||||||
|
// determines the actual end, so use an upper bound which is tight if this
|
||||||
|
// IFields is last one in span_.
|
||||||
|
end_.push_back(span_.size());
|
||||||
|
|
||||||
|
if (HWY_UNLIKELY(SkipField())) {
|
||||||
|
end_.pop_back(); // undo `push_back` to keep the stack balanced
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t num_u32; // not including itself because this..
|
||||||
|
operator()(num_u32); // increments result_.pos for the num_u32 field
|
||||||
|
|
||||||
|
// Ensure we have that much data and set end_.
|
||||||
|
if (HWY_UNLIKELY(result_.pos + num_u32 > span_.size())) {
|
||||||
|
NotifyInvalid("Invalid IFields: pos %zu + num_u32 %u > size %zu\n",
|
||||||
|
result_.pos, num_u32, span_.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
end_.back() = result_.pos + num_u32;
|
||||||
|
|
||||||
|
VisitorBase::operator()(fields);
|
||||||
|
|
||||||
|
HWY_ASSERT(!end_.empty() && result_.pos <= end_.back());
|
||||||
|
// Count extra, which indicates old code and new data.
|
||||||
|
result_.extra_u32 += end_.back() - result_.pos;
|
||||||
|
end_.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override because ReadVisitor also does bounds checking.
|
||||||
|
bool SkipField() override {
|
||||||
|
// If invalid, all bets are off and we don't count missing fields.
|
||||||
|
if (HWY_UNLIKELY(AnyInvalid())) return true;
|
||||||
|
|
||||||
|
// Reaching the end of the stored size, or the span, is not invalid -
|
||||||
|
// it happens when we read old data with new code.
|
||||||
|
if (HWY_UNLIKELY(result_.pos >= end_.back())) {
|
||||||
|
result_.missing_fields++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override so that operator()(std::vector<T>&) resizes the vector.
|
||||||
|
bool IsReading() const override { return true; }
|
||||||
|
|
||||||
|
IFields::ReadResult Result() {
|
||||||
|
if (HWY_UNLIKELY(AnyInvalid())) result_.pos = 0;
|
||||||
|
return result_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const hwy::Span<const uint32_t> span_;
|
||||||
|
IFields::ReadResult result_;
|
||||||
|
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
|
||||||
|
// but read in SkipField.
|
||||||
|
std::vector<uint32_t> end_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class WriteVisitor : public VisitorBase {
|
||||||
|
public:
|
||||||
|
WriteVisitor(std::vector<uint32_t>& storage) : storage_(storage) {}
|
||||||
|
|
||||||
|
// Note: while writing, only string lengths/characters can trigger AnyInvalid,
|
||||||
|
// so we don't have to check SkipField.
|
||||||
|
|
||||||
|
void operator()(uint32_t& value) override { storage_.push_back(value); }
|
||||||
|
|
||||||
|
void operator()(float& value) override {
|
||||||
|
storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
|
||||||
|
CheckF32(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(std::string& value) override {
|
||||||
|
constexpr size_t k4 = sizeof(uint32_t);
|
||||||
|
|
||||||
|
// Write length.
|
||||||
|
uint32_t num_u32 = hwy::DivCeil(value.size(), k4);
|
||||||
|
if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
|
||||||
|
operator()(num_u32); // always valid
|
||||||
|
|
||||||
|
// Copy whole uint32_t.
|
||||||
|
const size_t num_whole_u32 = value.size() / k4;
|
||||||
|
for (uint32_t i = 0; i < num_whole_u32; ++i) {
|
||||||
|
uint32_t u32 = 0;
|
||||||
|
hwy::CopyBytes(value.data() + i * k4, &u32, k4);
|
||||||
|
if (HWY_UNLIKELY(!CheckStringU32(u32, i, num_u32))) return;
|
||||||
|
storage_.push_back(u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read remaining bytes into least-significant bits of u32.
|
||||||
|
const size_t remainder = value.size() - num_whole_u32 * k4;
|
||||||
|
if (remainder != 0) {
|
||||||
|
HWY_DASSERT(remainder < k4);
|
||||||
|
uint32_t u32 = 0;
|
||||||
|
for (size_t i = 0; i < remainder; ++i) {
|
||||||
|
const char c = value[num_whole_u32 * k4 + i];
|
||||||
|
const uint32_t next = static_cast<uint32_t>(static_cast<uint8_t>(c));
|
||||||
|
u32 += next << (i * 8);
|
||||||
|
}
|
||||||
|
if (HWY_UNLIKELY(!CheckStringU32(u32, num_whole_u32, num_u32))) return;
|
||||||
|
storage_.push_back(u32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(IFields& fields) override {
|
||||||
|
const size_t pos_before_size = storage_.size();
|
||||||
|
storage_.push_back(0); // placeholder, updated below
|
||||||
|
|
||||||
|
VisitorBase::operator()(fields);
|
||||||
|
|
||||||
|
HWY_ASSERT(storage_[pos_before_size] == 0);
|
||||||
|
// Number of u32 written, including the one storing that number.
|
||||||
|
const uint32_t num_u32 = storage_.size() - pos_before_size;
|
||||||
|
HWY_ASSERT(num_u32 != 0); // at least one due to push_back above
|
||||||
|
// Store the payload size, not including the num_u32 field itself, because
|
||||||
|
// that is more convenient for ReadVisitor.
|
||||||
|
storage_[pos_before_size] = num_u32 - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<uint32_t>& storage_;
|
||||||
|
};
|
||||||
|
|
||||||
|
IFields::~IFields() = default;
|
||||||
|
|
||||||
|
void IFields::Print() const {
|
||||||
|
PrintVisitor visitor;
|
||||||
|
// VisitFields is non-const. It is safe to cast because PrintVisitor does not
|
||||||
|
// modify the fields.
|
||||||
|
visitor(*const_cast<IFields*>(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
|
||||||
|
size_t pos) {
|
||||||
|
ReadVisitor visitor(span, pos);
|
||||||
|
visitor(*this);
|
||||||
|
return visitor.Result();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IFields::AppendTo(std::vector<uint32_t>& storage) const {
|
||||||
|
// VisitFields is non-const. It is safe to cast because WriteVisitor does not
|
||||||
|
// modify the fields.
|
||||||
|
IFields& fields = *const_cast<IFields*>(this);
|
||||||
|
|
||||||
|
// Reduce allocations, but not in debug builds so we notice any iterator
|
||||||
|
// invalidation bugs.
|
||||||
|
if constexpr (!HWY_IS_DEBUG_BUILD) {
|
||||||
|
storage.reserve(storage.size() + 256);
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteVisitor visitor(storage);
|
||||||
|
visitor(fields);
|
||||||
|
return !visitor.AnyInvalid();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
@ -0,0 +1,200 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||||
|
|
||||||
|
// Simple serialization/deserialization for user-defined classes, inspired by
|
||||||
|
// BSD-licensed code Copyright (c) the JPEG XL Project Authors:
|
||||||
|
// https://github.com/libjxl/libjxl, lib/jxl/fields.h.
|
||||||
|
|
||||||
|
// IWYU pragma: begin_exports
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Design goals:
|
||||||
|
// - self-contained to simplify installing/building, no separate compiler (rules
|
||||||
|
// out Protocol Buffers, FlatBuffers, Cap'n Proto, Apache Thrift).
|
||||||
|
// - simplicity: small codebase without JIT (rules out Apache Fury and bitsery).
|
||||||
|
// - old code can read new data, and new code can read old data (rules out yas
|
||||||
|
// and msgpack). This avoids rewriting weights when we add a new field.
|
||||||
|
// - no user-specified versions (rules out cereal) nor field names (rules out
|
||||||
|
// JSON and GGUF). These are error-prone; users should just be able to append
|
||||||
|
// new fields.
|
||||||
|
//
|
||||||
|
// Non-goals:
|
||||||
|
// - anything better than reasonable encoding size and decode speed: we only
|
||||||
|
// anticipate ~KiB of data, alongside ~GiB of separately compressed weights.
|
||||||
|
// - features such as maps, interfaces, random access, and optional/deleted
|
||||||
|
// fields: not required for the intended use case of `ModelConfig`.
|
||||||
|
// - support any other languages than C++ and Python (for the exporter).
|
||||||
|
|
||||||
|
class IFields; // breaks circular dependency
|
||||||
|
|
||||||
|
// Visitors are internal-only, but their base class is visible to user code
|
||||||
|
// because their `IFields::VisitFields` calls `visitor.operator()`.
|
||||||
|
//
|
||||||
|
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
|
||||||
|
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
|
||||||
|
class IFieldsVisitor {
|
||||||
|
public:
|
||||||
|
virtual ~IFieldsVisitor();
|
||||||
|
|
||||||
|
// Indicates whether NotifyInvalid was called for any field. Once set, this is
|
||||||
|
// sticky for all IFields visited by this visitor.
|
||||||
|
bool AnyInvalid() const { return any_invalid_; }
|
||||||
|
|
||||||
|
// None of these fail directly, but they call NotifyInvalid() if any value
|
||||||
|
// is out of range. A single generic/overloaded function is required to
|
||||||
|
// support `std::vector<T>`.
|
||||||
|
virtual void operator()(uint32_t& value) = 0;
|
||||||
|
virtual void operator()(float& value) = 0;
|
||||||
|
virtual void operator()(std::string& value) = 0;
|
||||||
|
virtual void operator()(IFields& fields) = 0; // recurse into nested fields
|
||||||
|
|
||||||
|
// bool and enum fields are actually stored as uint32_t.
|
||||||
|
void operator()(bool& value) {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
uint32_t u32 = value ? 1 : 0;
|
||||||
|
operator()(u32);
|
||||||
|
if (HWY_UNLIKELY(u32 > 1)) {
|
||||||
|
return NotifyInvalid("Invalid bool %u\n", u32);
|
||||||
|
}
|
||||||
|
value = (u32 == 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename EnumT, hwy::EnableIf<std::is_enum_v<EnumT>>* = nullptr>
|
||||||
|
void operator()(EnumT& value) {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
uint32_t u32 = static_cast<uint32_t>(value);
|
||||||
|
operator()(u32);
|
||||||
|
if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
|
||||||
|
return NotifyInvalid("Invalid enum %u\n");
|
||||||
|
}
|
||||||
|
value = static_cast<EnumT>(u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(std::vector<T>& value) {
|
||||||
|
if (HWY_UNLIKELY(SkipField())) return;
|
||||||
|
|
||||||
|
uint32_t num = static_cast<uint32_t>(value.size());
|
||||||
|
operator()(num);
|
||||||
|
if (HWY_UNLIKELY(num > 64 * 1024)) {
|
||||||
|
return NotifyInvalid("Vector too long %u\n", num);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (IsReading()) {
|
||||||
|
value.resize(num);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < value.size(); ++i) {
|
||||||
|
operator()(value[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Prints a message and causes subsequent AnyInvalid() to return true.
|
||||||
|
void NotifyInvalid(const char* fmt, ...);
|
||||||
|
|
||||||
|
// Must check this before modifying any field, and if it returns true,
|
||||||
|
// avoid doing so. This is important for strings and vectors in the
|
||||||
|
// "new code, old data" test: resizing them may destroy their contents.
|
||||||
|
virtual bool SkipField() { return AnyInvalid(); }
|
||||||
|
// For operator()(std::vector&).
|
||||||
|
virtual bool IsReading() const { return false; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool any_invalid_ = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Abstract base class for user-defined serializable classes, which are
|
||||||
|
// forward- and backward compatible collection of fields (members). This means
|
||||||
|
// old code can safely read new data, and new code can still handle old data.
|
||||||
|
//
|
||||||
|
// Fields are written in the unchanging order established by the user-defined
|
||||||
|
// `VisitFields`; any new fields must be visited after all existing fields in
|
||||||
|
// the same `IFields`. We encode each into `uint32_t` storage for simplicity.
|
||||||
|
//
|
||||||
|
// HOWTO:
|
||||||
|
// - basic usage: define a struct with member variables ("fields") and their
|
||||||
|
// initializers, e.g. `uint32_t field = 0;`. Then define a
|
||||||
|
// `void VisitFields(IFieldsVisitor& v)` member function that calls
|
||||||
|
// `v(field);` etc. for each field, and a `const char* Name()` function used
|
||||||
|
// as a caption when printing.
|
||||||
|
//
|
||||||
|
// - enum fields: define `enum class EnumT` and `bool EnumValid(EnumT)`, then
|
||||||
|
// call `v(field);` as usual. Note that `EnumT` is not extendable insofar as
|
||||||
|
// `EnumValid` returns false for values beyond the initially known ones. You
|
||||||
|
// can add placeholders, which requires user code to know how to handle them,
|
||||||
|
// or later add new fields including enums to override the first enum.
|
||||||
|
struct IFields {
|
||||||
|
virtual ~IFields();
|
||||||
|
|
||||||
|
// User-defined caption used during Print().
|
||||||
|
virtual const char* Name() const = 0;
|
||||||
|
|
||||||
|
// User-defined, called by IFieldsVisitor::operator()(IFields&).
|
||||||
|
virtual void VisitFields(IFieldsVisitor& visitor) = 0;
|
||||||
|
|
||||||
|
// Prints name and fields to stderr.
|
||||||
|
void Print() const;
|
||||||
|
|
||||||
|
struct ReadResult {
|
||||||
|
ReadResult(size_t pos) : pos(pos), missing_fields(0), extra_u32(0) {}
|
||||||
|
|
||||||
|
// Where to resume reading in the next Read() call, or 0 if there was an
|
||||||
|
// unrecoverable error: any field has an invalid value, or the span is
|
||||||
|
// shorter than the data says it should be. If so, do not use the fields nor
|
||||||
|
// continue reading.
|
||||||
|
size_t pos;
|
||||||
|
// From the perspective of VisitFields, how many more fields would have
|
||||||
|
// been read beyond the stored size. If non-zero, the data is older than
|
||||||
|
// the code, but valid, and extra_u32 should be zero.
|
||||||
|
uint32_t missing_fields;
|
||||||
|
// How many extra u32 are in the stored size, vs. what we actually read as
|
||||||
|
// requested by VisitFields. If non-zero,, the data is newer than the code,
|
||||||
|
// but valid, and missing_fields should be zero.
|
||||||
|
uint32_t extra_u32;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reads fields starting at `span[pos]`.
|
||||||
|
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);
|
||||||
|
|
||||||
|
// Returns false if there was an unrecoverable error, typically because a
|
||||||
|
// field has an invalid value. If so, `storage` is undefined.
|
||||||
|
bool AppendTo(std::vector<uint32_t>& storage) const;
|
||||||
|
|
||||||
|
// Convenience wrapper for AppendTo when we only write once.
|
||||||
|
std::vector<uint32_t> Write() const {
|
||||||
|
std::vector<uint32_t> storage;
|
||||||
|
if (!AppendTo(storage)) storage.clear();
|
||||||
|
return storage;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
|
||||||
|
|
@ -0,0 +1,354 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "compression/fields.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#if !HWY_TEST_STANDALONE
|
||||||
|
class FieldsTest : public testing::Test {};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void MaybePrint(const IFields& fields) {
|
||||||
|
if (HWY_IS_DEBUG_BUILD) {
|
||||||
|
fields.Print();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void CheckVectorEqual(const std::vector<T>& a, const std::vector<T>& b) {
|
||||||
|
EXPECT_EQ(a.size(), b.size());
|
||||||
|
for (size_t i = 0; i < a.size(); ++i) {
|
||||||
|
if constexpr (std::is_base_of_v<IFields, T>) {
|
||||||
|
a[i].CheckEqual(b[i]);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(a[i], b[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
enum class Enum : uint32_t {
|
||||||
|
k1 = 1,
|
||||||
|
k3 = 3,
|
||||||
|
k8 = 8,
|
||||||
|
};
|
||||||
|
HWY_MAYBE_UNUSED bool EnumValid(Enum e) {
|
||||||
|
return e == Enum::k1 || e == Enum::k3 || e == Enum::k8;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains all supported types except IFields and std::vector<IFields>.
|
||||||
|
struct Nested : public IFields {
|
||||||
|
Nested() : nested_u32(0) {} // for std::vector
|
||||||
|
explicit Nested(uint32_t u32) : nested_u32(u32) {}
|
||||||
|
|
||||||
|
const char* Name() const override { return "Nested"; }
|
||||||
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
|
visitor(nested_u32);
|
||||||
|
visitor(nested_bool);
|
||||||
|
visitor(nested_vector);
|
||||||
|
visitor(nested_enum);
|
||||||
|
visitor(nested_str);
|
||||||
|
visitor(nested_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckEqual(const Nested& n) const {
|
||||||
|
EXPECT_EQ(nested_u32, n.nested_u32);
|
||||||
|
EXPECT_EQ(nested_bool, n.nested_bool);
|
||||||
|
CheckVectorEqual(nested_vector, n.nested_vector);
|
||||||
|
EXPECT_EQ(nested_enum, n.nested_enum);
|
||||||
|
EXPECT_EQ(nested_str, n.nested_str);
|
||||||
|
EXPECT_EQ(nested_f, n.nested_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t nested_u32; // set in ctor
|
||||||
|
bool nested_bool = true;
|
||||||
|
std::vector<uint32_t> nested_vector = {1, 2, 3};
|
||||||
|
Enum nested_enum = Enum::k1;
|
||||||
|
std::string nested_str = "nested";
|
||||||
|
float nested_f = 1.125f;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Contains all supported types.
|
||||||
|
struct OldFields : public IFields {
|
||||||
|
const char* Name() const override { return "OldFields"; }
|
||||||
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
|
visitor(old_str);
|
||||||
|
visitor(old_nested);
|
||||||
|
visitor(old1);
|
||||||
|
visitor(old_vec_str);
|
||||||
|
visitor(old_vec_nested);
|
||||||
|
visitor(old_f);
|
||||||
|
visitor(old_enum);
|
||||||
|
visitor(old_bool);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Template allows comparing with NewFields.
|
||||||
|
template <typename Other>
|
||||||
|
void CheckEqual(const Other& n) const {
|
||||||
|
EXPECT_EQ(old_str, n.old_str);
|
||||||
|
old_nested.CheckEqual(n.old_nested);
|
||||||
|
EXPECT_EQ(old1, n.old1);
|
||||||
|
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||||
|
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||||
|
EXPECT_EQ(old_f, n.old_f);
|
||||||
|
EXPECT_EQ(old_enum, n.old_enum);
|
||||||
|
EXPECT_EQ(old_bool, n.old_bool);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string old_str = "old";
|
||||||
|
Nested old_nested = Nested(0);
|
||||||
|
uint32_t old1 = 1;
|
||||||
|
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||||
|
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||||
|
float old_f = 1.125f;
|
||||||
|
Enum old_enum = Enum::k1;
|
||||||
|
bool old_bool = true;
|
||||||
|
}; // OldFields
|
||||||
|
|
||||||
|
// Simulates adding new fields of all types to an existing struct.
|
||||||
|
struct NewFields : public IFields {
|
||||||
|
const char* Name() const override { return "NewFields"; }
|
||||||
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
|
visitor(old_str);
|
||||||
|
visitor(old_nested);
|
||||||
|
visitor(old1);
|
||||||
|
visitor(old_vec_str);
|
||||||
|
visitor(old_vec_nested);
|
||||||
|
visitor(old_f);
|
||||||
|
visitor(old_enum);
|
||||||
|
visitor(old_bool);
|
||||||
|
|
||||||
|
// Change order of field types relative to OldFields to ensure that works.
|
||||||
|
visitor(new_nested);
|
||||||
|
visitor(new_bool);
|
||||||
|
visitor(new_vec_nested);
|
||||||
|
visitor(new_f);
|
||||||
|
visitor(new_vec_str);
|
||||||
|
visitor(new_enum);
|
||||||
|
visitor(new2);
|
||||||
|
visitor(new_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckEqual(const NewFields& n) const {
|
||||||
|
EXPECT_EQ(old_str, n.old_str);
|
||||||
|
old_nested.CheckEqual(n.old_nested);
|
||||||
|
EXPECT_EQ(old1, n.old1);
|
||||||
|
CheckVectorEqual(old_vec_str, n.old_vec_str);
|
||||||
|
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
|
||||||
|
EXPECT_EQ(old_f, n.old_f);
|
||||||
|
EXPECT_EQ(old_enum, n.old_enum);
|
||||||
|
EXPECT_EQ(old_bool, n.old_bool);
|
||||||
|
|
||||||
|
new_nested.CheckEqual(n.new_nested);
|
||||||
|
EXPECT_EQ(new_bool, n.new_bool);
|
||||||
|
CheckVectorEqual(new_vec_nested, n.new_vec_nested);
|
||||||
|
EXPECT_EQ(new_f, n.new_f);
|
||||||
|
CheckVectorEqual(new_vec_str, n.new_vec_str);
|
||||||
|
EXPECT_EQ(new_enum, n.new_enum);
|
||||||
|
EXPECT_EQ(new2, n.new2);
|
||||||
|
EXPECT_EQ(new_str, n.new_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copied from OldFields to match the use case of adding new fields. If we
|
||||||
|
// write an OldFields member, that would change the layout due to its size.
|
||||||
|
std::string old_str = "old";
|
||||||
|
Nested old_nested = Nested(0);
|
||||||
|
uint32_t old1 = 1;
|
||||||
|
std::vector<std::string> old_vec_str = {"abc", "1234"};
|
||||||
|
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
|
||||||
|
float old_f = 1.125f;
|
||||||
|
Enum old_enum = Enum::k1;
|
||||||
|
bool old_bool = true;
|
||||||
|
|
||||||
|
Nested new_nested = Nested(999);
|
||||||
|
bool new_bool = false;
|
||||||
|
std::vector<Nested> new_vec_nested = {Nested(2), Nested(3)};
|
||||||
|
float new_f = -2.0f;
|
||||||
|
std::vector<std::string> new_vec_str = {"AB", std::string(), "56789"};
|
||||||
|
Enum new_enum = Enum::k3;
|
||||||
|
uint32_t new2 = 2;
|
||||||
|
std::string new_str = std::string(); // empty is allowed
|
||||||
|
}; // NewFields
|
||||||
|
|
||||||
|
// Changes all fields to non-default values.
|
||||||
|
NewFields ModifiedNewFields() {
|
||||||
|
NewFields n;
|
||||||
|
n.old_str = "old2";
|
||||||
|
n.old_nested = Nested(5);
|
||||||
|
n.old1 = 11;
|
||||||
|
n.old_vec_str = {"abc2", "431", "ZZ"};
|
||||||
|
n.old_vec_nested = {Nested(9)};
|
||||||
|
n.old_f = -2.5f;
|
||||||
|
n.old_enum = Enum::k3;
|
||||||
|
n.old_bool = false;
|
||||||
|
|
||||||
|
n.new_nested = Nested(55);
|
||||||
|
n.new_bool = true;
|
||||||
|
n.new_vec_nested = {Nested(3), Nested(33), Nested(333)};
|
||||||
|
n.new_f = 4.f;
|
||||||
|
n.new_vec_str = {"4321", "321", "21", "1"};
|
||||||
|
n.new_enum = Enum::k8;
|
||||||
|
n.new2 = 22;
|
||||||
|
n.new_str = "new and even longer";
|
||||||
|
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
using Span = hwy::Span<const uint32_t>;
|
||||||
|
|
||||||
|
using ReadResult = IFields::ReadResult;
|
||||||
|
void CheckConsumedAll(const ReadResult& result, size_t storage_size) {
|
||||||
|
EXPECT_NE(0, storage_size); // Ensure we notice failure (pos == 0).
|
||||||
|
EXPECT_EQ(storage_size, result.pos);
|
||||||
|
EXPECT_EQ(0, result.missing_fields);
|
||||||
|
EXPECT_EQ(0, result.extra_u32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we do not change any fields, Write+Read returns the defaults.
|
||||||
|
TEST(FieldsTest, TestNewMatchesDefaults) {
|
||||||
|
NewFields new_fields;
|
||||||
|
const std::vector<uint32_t> storage = new_fields.Write();
|
||||||
|
|
||||||
|
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||||
|
CheckConsumedAll(result, storage.size());
|
||||||
|
|
||||||
|
NewFields().CheckEqual(new_fields);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change fields from default and check that Write+Read returns them again.
|
||||||
|
TEST(FieldsTest, TestRoundTrip) {
|
||||||
|
NewFields new_fields = ModifiedNewFields();
|
||||||
|
const std::vector<uint32_t> storage = new_fields.Write();
|
||||||
|
|
||||||
|
NewFields copy;
|
||||||
|
const ReadResult result = copy.Read(Span(storage), 0);
|
||||||
|
CheckConsumedAll(result, storage.size());
|
||||||
|
|
||||||
|
new_fields.CheckEqual(copy);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refuse to write invalid floats.
|
||||||
|
TEST(FieldsTest, TestInvalidFloat) {
|
||||||
|
NewFields new_fields;
|
||||||
|
new_fields.new_f = std::numeric_limits<float>::infinity();
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
|
||||||
|
new_fields.new_f = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refuse to write invalid strings.
|
||||||
|
TEST(FieldsTest, TestInvalidString) {
|
||||||
|
NewFields new_fields;
|
||||||
|
// Four zero bytes
|
||||||
|
new_fields.new_str.assign(4, '\0');
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
|
||||||
|
// Too long
|
||||||
|
new_fields.new_str.assign(257, 'a');
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
|
||||||
|
// First byte not ASCII
|
||||||
|
new_fields.new_str.assign("123");
|
||||||
|
new_fields.new_str[0] = 128;
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
|
||||||
|
// Upper byte in later u32 not ASCII
|
||||||
|
new_fields.new_str.assign("ABCDEFGH");
|
||||||
|
new_fields.new_str[7] = 255;
|
||||||
|
EXPECT_TRUE(new_fields.Write().empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write two structs to the same storage.
|
||||||
|
TEST(FieldsTest, TestMultipleWrite) {
|
||||||
|
const NewFields modified = ModifiedNewFields();
|
||||||
|
std::vector<uint32_t> storage = modified.Write();
|
||||||
|
const size_t modified_size = storage.size();
|
||||||
|
const NewFields defaults;
|
||||||
|
defaults.AppendTo(storage);
|
||||||
|
|
||||||
|
// Start with defaults to ensure Read retrieves the modified values.
|
||||||
|
NewFields modified_copy;
|
||||||
|
const ReadResult result1 = modified_copy.Read(Span(storage), 0);
|
||||||
|
CheckConsumedAll(result1, modified_size);
|
||||||
|
modified.CheckEqual(modified_copy);
|
||||||
|
|
||||||
|
// Start with modified values to ensure Read retrieves the defaults.
|
||||||
|
NewFields defaults_copy = modified;
|
||||||
|
const ReadResult result2 = defaults_copy.Read(Span(storage), result1.pos);
|
||||||
|
CheckConsumedAll(result2, storage.size());
|
||||||
|
defaults.CheckEqual(defaults_copy);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write old defaults, read old using new code.
|
||||||
|
TEST(FieldsTest, TestNewCodeOldData) {
|
||||||
|
OldFields old_fields;
|
||||||
|
const std::vector<uint32_t> storage = old_fields.Write();
|
||||||
|
|
||||||
|
// Start with modified old values to ensure old defaults overwrite them.
|
||||||
|
NewFields new_fields = ModifiedNewFields();
|
||||||
|
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||||
|
MaybePrint(new_fields);
|
||||||
|
EXPECT_NE(0, result.pos); // did not fail
|
||||||
|
EXPECT_NE(0, result.missing_fields);
|
||||||
|
EXPECT_EQ(0, result.extra_u32);
|
||||||
|
old_fields.CheckEqual(new_fields); // old fields are the same in both
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write old defaults, ensure new defaults remain unchanged.
|
||||||
|
TEST(FieldsTest, TestNewCodeOldDataNewUnchanged) {
|
||||||
|
OldFields old_fields;
|
||||||
|
const std::vector<uint32_t> storage = old_fields.Write();
|
||||||
|
|
||||||
|
NewFields new_fields;
|
||||||
|
const ReadResult result = new_fields.Read(Span(storage), 0);
|
||||||
|
MaybePrint(new_fields);
|
||||||
|
EXPECT_NE(0, result.pos); // did not fail
|
||||||
|
EXPECT_NE(0, result.missing_fields);
|
||||||
|
EXPECT_EQ(0, result.extra_u32);
|
||||||
|
NewFields().CheckEqual(new_fields); // new fields match their defaults
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write new defaults, read using old code.
|
||||||
|
TEST(FieldsTest, TestOldCodeNewData) {
|
||||||
|
NewFields new_fields;
|
||||||
|
const std::vector<uint32_t> storage = new_fields.Write();
|
||||||
|
|
||||||
|
OldFields old_fields;
|
||||||
|
const ReadResult result = old_fields.Read(Span(storage), 0);
|
||||||
|
MaybePrint(old_fields);
|
||||||
|
EXPECT_NE(0, result.pos); // did not fail
|
||||||
|
EXPECT_EQ(0, result.missing_fields);
|
||||||
|
EXPECT_NE(0, result.extra_u32);
|
||||||
|
EXPECT_EQ(storage.size(), result.pos + result.extra_u32);
|
||||||
|
|
||||||
|
old_fields.CheckEqual(new_fields); // old fields are the same in both
|
||||||
|
// (Can't check new fields because we only read OldFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
HWY_TEST_MAIN();
|
||||||
Loading…
Reference in New Issue