Serialization for class members for use with ModelConfig

PiperOrigin-RevId: 689720027
This commit is contained in:
Jan Wassenberg 2024-10-25 03:11:55 -07:00 committed by Copybara-Service
parent efff64605a
commit 52af531820
5 changed files with 897 additions and 0 deletions

View File

@ -45,6 +45,8 @@ set(SOURCES
compression/compress.cc compression/compress.cc
compression/compress.h compression/compress.h
compression/compress-inl.h compression/compress-inl.h
compression/fields.cc
compression/fields.h
compression/io_win.cc compression/io_win.cc
compression/io.cc compression/io.cc
compression/io.h compression/io.h
@ -150,6 +152,7 @@ set(GEMMA_TEST_FILES
compression/blob_store_test.cc compression/blob_store_test.cc
compression/compress_test.cc compression/compress_test.cc
compression/distortion_test.cc compression/distortion_test.cc
compression/fields_test.cc
compression/nuq_test.cc compression/nuq_test.cc
compression/sfp_test.cc compression/sfp_test.cc
evals/gemma_test.cc evals/gemma_test.cc

View File

@ -37,6 +37,26 @@ cc_library(
] + FILE_DEPS, ] + FILE_DEPS,
) )
# Library target for the field (de)serialization code: fields.cc / fields.h.
# Its only dependency is the Highway base library.
cc_library(
name = "fields",
srcs = ["fields.cc"],
hdrs = ["fields.h"],
deps = [
"@highway//:hwy",
],
)
# Unit test target for the ":fields" library, built from fields_test.cc.
# Links gtest_main plus Highway and its test utilities.
cc_test(
name = "fields_test",
srcs = ["fields_test.cc"],
deps = [
":fields",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
cc_library( cc_library(
name = "blob_store", name = "blob_store",
srcs = ["blob_store.cc"], srcs = ["blob_store.cc"],

320
compression/fields.cc Normal file
View File

@ -0,0 +1,320 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/fields.h"
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <string>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
namespace gcpp {
// Out-of-line definition of the virtual destructor declared in fields.h.
IFieldsVisitor::~IFieldsVisitor() = default;
// Records that an invalid value was encountered and prints a diagnostic.
// `fmt` is a printf-style format string; the variadic arguments must match its
// conversions. The message goes to stderr, and AnyInvalid() returns true from
// now on.
void IFieldsVisitor::NotifyInvalid(const char* fmt, ...) {
  any_invalid_ = true;

  va_list varargs;
  va_start(varargs, fmt);
  vfprintf(stderr, fmt, varargs);
  va_end(varargs);
}
// Common base of the concrete visitors in this file. Implements recursion into
// nested IFields and provides validation helpers; each helper calls
// NotifyInvalid() when its check fails.
class VisitorBase : public IFieldsVisitor {
 public:
  VisitorBase() = default;
  ~VisitorBase() override = default;

  // This is the only call site of IFields::VisitFields.
  void operator()(IFields& fields) override { fields.VisitFields(*this); }

 protected:  // Functions shared between ReadVisitor and WriteVisitor:
  // Flags `value` as invalid if it is infinite or NaN.
  void CheckF32(float value) {
    const bool is_finite =
        !hwy::ScalarIsInf(value) && !hwy::ScalarIsNaN(value);
    if (HWY_UNLIKELY(!is_finite)) {
      NotifyInvalid("Invalid float %g\n", value);
    }
  }

  // Returns whether a string occupying `num_u32` words is acceptably short.
  // Returning bool avoids having to check AnyInvalid() after calling.
  bool CheckStringLength(uint32_t num_u32) {
    // Disallow long strings for safety, and to prevent them being used for
    // arbitrary data (we also require them to be ASCII).
    const bool short_enough = num_u32 <= 64;
    if (HWY_UNLIKELY(!short_enough)) {
      NotifyInvalid("String num_u32=%u too large\n", num_u32);
    }
    return short_enough;
  }

  // Returns whether `u32`, the i-th of `num_u32` words of a string, holds
  // acceptable characters.
  bool CheckStringU32(uint32_t u32, uint32_t i, uint32_t num_u32) {
    // Although strings are zero-padded to u32, an entire u32 should not be
    // zero, and upper bits should not be set (ASCII-only).
    const bool all_zero = (u32 == 0);
    const bool has_high_bit = (u32 & 0x80808080) != 0;
    if (HWY_UNLIKELY(all_zero || has_high_bit)) {
      NotifyInvalid("Invalid characters %x at %u of %u\n", u32, i, num_u32);
      return false;
    }
    return true;
  }
};
// Visitor that prints each field to stderr, one per line, indenting the
// members of nested IFields below their caption.
class PrintVisitor : public VisitorBase {
 public:
  void operator()(uint32_t& value) override {
    fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
  }

  void operator()(float& value) override {
    fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
  }

  void operator()(std::string& value) override {
    fprintf(stderr, "%sStr %s\n", indent_.c_str(), value.c_str());
  }

  // Prints the caption of `fields`, then its members one level deeper.
  void operator()(IFields& fields) override {
    fprintf(stderr, "%s%s\n", indent_.c_str(), fields.Name());

    // Grow and shrink by the same constant. If the two amounts differ, the
    // subtraction below can wrap around and resize() to a huge length.
    indent_.append(kIndentWidth, ' ');
    VisitorBase::operator()(fields);
    HWY_ASSERT(indent_.size() >= kIndentWidth);
    indent_.resize(indent_.size() - kIndentWidth);
  }

 private:
  // Number of spaces added per nesting level.
  static constexpr size_t kIndentWidth = 2;

  std::string indent_;
};
class ReadVisitor : public VisitorBase {
public:
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
: span_(span), result_(pos) {}
~ReadVisitor() {
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
}
// All data is read through this overload.
void operator()(uint32_t& value) override {
if (HWY_UNLIKELY(SkipField())) return;
value = span_[result_.pos++];
}
void operator()(float& value) override {
if (HWY_UNLIKELY(SkipField())) return;
uint32_t u32 = hwy::BitCastScalar<uint32_t>(value);
operator()(u32);
value = hwy::BitCastScalar<float>(u32);
CheckF32(value);
}
void operator()(std::string& value) override {
if (HWY_UNLIKELY(SkipField())) return;
uint32_t num_u32; // not including itself because this..
operator()(num_u32); // increments result_.pos for the num_u32 field
if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
// Ensure we have that much data.
if (HWY_UNLIKELY(result_.pos + num_u32 > end_.back())) {
NotifyInvalid("Invalid string: pos %zu + num_u32 %u > end %zu\n",
result_.pos, num_u32, span_.size());
return;
}
constexpr size_t k4 = sizeof(uint32_t);
value.resize(num_u32 * k4);
for (uint32_t i = 0; i < num_u32; ++i) {
uint32_t u32;
operator()(u32);
(void)CheckStringU32(u32, i, num_u32);
hwy::CopyBytes(&u32, value.data() + i * k4, k4);
}
// Trim 0..3 trailing nulls.
const size_t pos = value.find_last_not_of('\0');
if (pos != std::string::npos) {
value.resize(pos + 1);
}
}
void operator()(IFields& fields) override {
// Our SkipField requires end_ to be set before reading num_u32, which
// determines the actual end, so use an upper bound which is tight if this
// IFields is last one in span_.
end_.push_back(span_.size());
if (HWY_UNLIKELY(SkipField())) {
end_.pop_back(); // undo `push_back` to keep the stack balanced
return;
}
uint32_t num_u32; // not including itself because this..
operator()(num_u32); // increments result_.pos for the num_u32 field
// Ensure we have that much data and set end_.
if (HWY_UNLIKELY(result_.pos + num_u32 > span_.size())) {
NotifyInvalid("Invalid IFields: pos %zu + num_u32 %u > size %zu\n",
result_.pos, num_u32, span_.size());
return;
}
end_.back() = result_.pos + num_u32;
VisitorBase::operator()(fields);
HWY_ASSERT(!end_.empty() && result_.pos <= end_.back());
// Count extra, which indicates old code and new data.
result_.extra_u32 += end_.back() - result_.pos;
end_.pop_back();
}
// Override because ReadVisitor also does bounds checking.
bool SkipField() override {
// If invalid, all bets are off and we don't count missing fields.
if (HWY_UNLIKELY(AnyInvalid())) return true;
// Reaching the end of the stored size, or the span, is not invalid -
// it happens when we read old data with new code.
if (HWY_UNLIKELY(result_.pos >= end_.back())) {
result_.missing_fields++;
return true;
}
return false;
}
// Override so that operator()(std::vector<T>&) resizes the vector.
bool IsReading() const override { return true; }
IFields::ReadResult Result() {
if (HWY_UNLIKELY(AnyInvalid())) result_.pos = 0;
return result_;
}
private:
const hwy::Span<const uint32_t> span_;
IFields::ReadResult result_;
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
// but read in SkipField.
std::vector<uint32_t> end_;
};
// Visitor that appends the encoded fields to the `storage` vector passed to
// the constructor, which is held by reference.
class WriteVisitor : public VisitorBase {
 public:
  WriteVisitor(std::vector<uint32_t>& storage) : storage_(storage) {}

  // Note: the overloads below do not check SkipField. Values are only flagged
  // as invalid by the CheckF32/CheckString* helpers, after or instead of
  // being appended.

  void operator()(uint32_t& value) override { storage_.push_back(value); }

  void operator()(float& value) override {
    storage_.push_back(hwy::BitCastScalar<uint32_t>(value));
    CheckF32(value);
  }

  // Appends the length in u32 words, then the characters packed four per u32
  // and zero-padded. Stops early if the length or any word is rejected.
  void operator()(std::string& value) override {
    constexpr size_t kBytesPerU32 = sizeof(uint32_t);
    const size_t num_bytes = value.size();

    // Write length.
    uint32_t num_u32 = hwy::DivCeil(num_bytes, kBytesPerU32);
    if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
    operator()(num_u32);  // always valid

    // Copy whole uint32_t.
    const size_t num_whole = num_bytes / kBytesPerU32;
    for (uint32_t idx = 0; idx < num_whole; ++idx) {
      uint32_t packed = 0;
      hwy::CopyBytes(value.data() + idx * kBytesPerU32, &packed, kBytesPerU32);
      if (HWY_UNLIKELY(!CheckStringU32(packed, idx, num_u32))) return;
      storage_.push_back(packed);
    }

    // Read remaining bytes into least-significant bits of the last u32.
    const size_t tail_bytes = num_bytes - num_whole * kBytesPerU32;
    if (tail_bytes == 0) return;
    HWY_DASSERT(tail_bytes < kBytesPerU32);

    uint32_t packed = 0;
    for (size_t b = 0; b < tail_bytes; ++b) {
      const uint8_t byte =
          static_cast<uint8_t>(value[num_whole * kBytesPerU32 + b]);
      // The byte lanes are disjoint, so OR is equivalent to addition here.
      packed |= static_cast<uint32_t>(byte) << (b * 8);
    }
    if (HWY_UNLIKELY(!CheckStringU32(packed, num_whole, num_u32))) return;
    storage_.push_back(packed);
  }

  // Appends a size word followed by the members of `fields`.
  void operator()(IFields& fields) override {
    const size_t size_slot = storage_.size();
    storage_.push_back(0);  // placeholder, updated below

    VisitorBase::operator()(fields);

    HWY_ASSERT(storage_[size_slot] == 0);
    // Number of u32 written, including the one storing that number.
    const uint32_t total_u32 = storage_.size() - size_slot;
    HWY_ASSERT(total_u32 != 0);  // at least one due to push_back above
    // Store the payload size, not including the size word itself, because
    // that is more convenient for ReadVisitor.
    storage_[size_slot] = total_u32 - 1;
  }

 private:
  std::vector<uint32_t>& storage_;
};
// Out-of-line definition of the virtual destructor declared in fields.h.
IFields::~IFields() = default;
void IFields::Print() const {
PrintVisitor visitor;
// VisitFields is non-const. It is safe to cast because PrintVisitor does not
// modify the fields.
visitor(*const_cast<IFields*>(this));
}
// Fills this object's fields from `span`, starting at index `pos`, and returns
// the reader's result (position and missing/extra counters).
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
                                  size_t pos) {
  ReadVisitor reader(span, pos);
  reader(*this);
  const IFields::ReadResult outcome = reader.Result();
  return outcome;
}
bool IFields::AppendTo(std::vector<uint32_t>& storage) const {
// VisitFields is non-const. It is safe to cast because WriteVisitor does not
// modify the fields.
IFields& fields = *const_cast<IFields*>(this);
// Reduce allocations, but not in debug builds so we notice any iterator
// invalidation bugs.
if constexpr (!HWY_IS_DEBUG_BUILD) {
storage.reserve(storage.size() + 256);
}
WriteVisitor visitor(storage);
visitor(fields);
return !visitor.AnyInvalid();
}
} // namespace gcpp

200
compression/fields.h Normal file
View File

@ -0,0 +1,200 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
// Simple serialization/deserialization for user-defined classes, inspired by
// BSD-licensed code Copyright (c) the JPEG XL Project Authors:
// https://github.com/libjxl/libjxl, lib/jxl/fields.h.
// IWYU pragma: begin_exports
#include <stddef.h>
#include <stdint.h>
#include <string>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// IWYU pragma: end_exports
namespace gcpp {
// Design goals:
// - self-contained to simplify installing/building, no separate compiler (rules
// out Protocol Buffers, FlatBuffers, Cap'n Proto, Apache Thrift).
// - simplicity: small codebase without JIT (rules out Apache Fury and bitsery).
// - old code can read new data, and new code can read old data (rules out yas
// and msgpack). This avoids rewriting weights when we add a new field.
// - no user-specified versions (rules out cereal) nor field names (rules out
// JSON and GGUF). These are error-prone; users should just be able to append
// new fields.
//
// Non-goals:
// - anything better than reasonable encoding size and decode speed: we only
// anticipate ~KiB of data, alongside ~GiB of separately compressed weights.
// - features such as maps, interfaces, random access, and optional/deleted
// fields: not required for the intended use case of `ModelConfig`.
// - support any other languages than C++ and Python (for the exporter).
class IFields; // breaks circular dependency
// Visitors are internal-only, but their base class is visible to user code
// because their `IFields::VisitFields` calls `visitor.operator()`.
//
// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
// derived from `IFields`, `bool`, `enum`, `std::vector<T>`.
class IFieldsVisitor {
 public:
  virtual ~IFieldsVisitor();

  // Indicates whether NotifyInvalid was called for any field. Once set, this is
  // sticky for all IFields visited by this visitor.
  bool AnyInvalid() const { return any_invalid_; }

  // None of these fail directly, but they call NotifyInvalid() if any value
  // is out of range. A single generic/overloaded function is required to
  // support `std::vector<T>`.
  virtual void operator()(uint32_t& value) = 0;
  virtual void operator()(float& value) = 0;
  virtual void operator()(std::string& value) = 0;
  virtual void operator()(IFields& fields) = 0;  // recurse into nested fields

  // bool and enum fields are actually stored as uint32_t.
  void operator()(bool& value) {
    if (HWY_UNLIKELY(SkipField())) return;

    uint32_t u32 = value ? 1 : 0;
    operator()(u32);
    if (HWY_UNLIKELY(u32 > 1)) {
      return NotifyInvalid("Invalid bool %u\n", u32);
    }
    value = (u32 == 1);
  }

  // Requires a `bool EnumValid(EnumT)` function for `EnumT`, which decides
  // whether the stored value is accepted.
  template <typename EnumT, hwy::EnableIf<std::is_enum_v<EnumT>>* = nullptr>
  void operator()(EnumT& value) {
    if (HWY_UNLIKELY(SkipField())) return;

    uint32_t u32 = static_cast<uint32_t>(value);
    operator()(u32);
    if (HWY_UNLIKELY(!EnumValid(static_cast<EnumT>(u32)))) {
      // `u32` must be passed to match the `%u` conversion; omitting it is
      // undefined behavior in the underlying vfprintf.
      return NotifyInvalid("Invalid enum %u\n", u32);
    }
    value = static_cast<EnumT>(u32);
  }

  // Vectors are stored as their element count followed by each element.
  template <typename T>
  void operator()(std::vector<T>& value) {
    if (HWY_UNLIKELY(SkipField())) return;

    uint32_t num = static_cast<uint32_t>(value.size());
    operator()(num);
    if (HWY_UNLIKELY(num > 64 * 1024)) {
      return NotifyInvalid("Vector too long %u\n", num);
    }

    // Only a reading visitor changes the size to the stored count.
    if (IsReading()) {
      value.resize(num);
    }
    for (size_t i = 0; i < value.size(); ++i) {
      operator()(value[i]);
    }
  }

 protected:
  // Prints a message and causes subsequent AnyInvalid() to return true.
  void NotifyInvalid(const char* fmt, ...);

  // Must check this before modifying any field, and if it returns true,
  // avoid doing so. This is important for strings and vectors in the
  // "new code, old data" test: resizing them may destroy their contents.
  virtual bool SkipField() { return AnyInvalid(); }

  // For operator()(std::vector&).
  virtual bool IsReading() const { return false; }

 private:
  bool any_invalid_ = false;
};
// Abstract base class for user-defined serializable classes, which are
// forward- and backward compatible collection of fields (members). This means
// old code can safely read new data, and new code can still handle old data.
//
// Fields are written in the unchanging order established by the user-defined
// `VisitFields`; any new fields must be visited after all existing fields in
// the same `IFields`. We encode each into `uint32_t` storage for simplicity.
//
// HOWTO:
// - basic usage: define a struct with member variables ("fields") and their
// initializers, e.g. `uint32_t field = 0;`. Then define a
// `void VisitFields(IFieldsVisitor& v)` member function that calls
// `v(field);` etc. for each field, and a `const char* Name()` function used
// as a caption when printing.
//
// - enum fields: define `enum class EnumT` and `bool EnumValid(EnumT)`, then
// call `v(field);` as usual. Note that `EnumT` is not extendable insofar as
// `EnumValid` returns false for values beyond the initially known ones. You
// can add placeholders, which requires user code to know how to handle them,
// or later add new fields including enums to override the first enum.
struct IFields {
  virtual ~IFields();

  // User-defined caption used during Print().
  virtual const char* Name() const = 0;

  // User-defined, called by IFieldsVisitor::operator()(IFields&).
  virtual void VisitFields(IFieldsVisitor& visitor) = 0;

  // Prints name and fields to stderr.
  void Print() const;

  // Outcome of Read().
  struct ReadResult {
    ReadResult(size_t pos) : pos(pos) {}

    // Where to resume reading in the next Read() call, or 0 if there was an
    // unrecoverable error: any field has an invalid value, or the span is
    // shorter than the data says it should be. If so, do not use the fields nor
    // continue reading.
    size_t pos;
    // From the perspective of VisitFields, how many more fields would have
    // been read beyond the stored size. If non-zero, the data is older than
    // the code, but valid, and extra_u32 should be zero.
    uint32_t missing_fields = 0;
    // How many extra u32 are in the stored size, vs. what we actually read as
    // requested by VisitFields. If non-zero, the data is newer than the code,
    // but valid, and missing_fields should be zero.
    uint32_t extra_u32 = 0;
  };

  // Reads fields starting at `span[pos]`.
  ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);

  // Returns false if there was an unrecoverable error, typically because a
  // field has an invalid value. If so, `storage` is undefined.
  bool AppendTo(std::vector<uint32_t>& storage) const;

  // Convenience wrapper for AppendTo when we only write once. Returns an
  // empty vector if AppendTo reported failure.
  std::vector<uint32_t> Write() const {
    std::vector<uint32_t> storage;
    const bool ok = AppendTo(storage);
    if (!ok) {
      storage.clear();
    }
    return storage;
  }
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_

354
compression/fields_test.cc Normal file
View File

@ -0,0 +1,354 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "compression/fields.h"
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <limits>
#include <type_traits>
#include "hwy/tests/hwy_gtest.h"
namespace gcpp {
namespace {
#if !HWY_TEST_STANDALONE
// Empty gtest fixture carrying the suite name used by the TEST macros below.
// Only compiled when HWY_TEST_STANDALONE is false (see the enclosing #if).
class FieldsTest : public testing::Test {};
#endif
// Prints `fields` to stderr, but only in debug builds.
void MaybePrint(const IFields& fields) {
  if (!HWY_IS_DEBUG_BUILD) return;
  fields.Print();
}
// Expects that `a` and `b` have the same length and equal elements. Elements
// derived from IFields are compared via their CheckEqual(), others via
// EXPECT_EQ.
template <typename T>
void CheckVectorEqual(const std::vector<T>& a, const std::vector<T>& b) {
  EXPECT_EQ(a.size(), b.size());
  // The mismatch was already reported above. Do not compare elements, because
  // indexing the shorter vector with the longer one's size is out of bounds.
  if (a.size() != b.size()) return;

  for (size_t i = 0; i < a.size(); ++i) {
    if constexpr (std::is_base_of_v<IFields, T>) {
      a[i].CheckEqual(b[i]);
    } else {
      EXPECT_EQ(a[i], b[i]);
    }
  }
}
// Test enum stored as uint32_t. Its values are deliberately non-contiguous;
// EnumValid() below accepts exactly these three.
enum class Enum : uint32_t {
k1 = 1,
k3 = 3,
k8 = 8,
};
// Returns true only for the enumerators declared in `Enum`.
HWY_MAYBE_UNUSED bool EnumValid(Enum e) {
  switch (e) {
    case Enum::k1:
    case Enum::k3:
    case Enum::k8:
      return true;
    default:
      return false;
  }
}
// Contains all supported types except IFields and std::vector<IFields>.
// Used as a member and as a vector element of OldFields and NewFields.
struct Nested : public IFields {
Nested() : nested_u32(0) {} // for std::vector
explicit Nested(uint32_t u32) : nested_u32(u32) {}
// Caption for this struct.
const char* Name() const override { return "Nested"; }
// Visits every member exactly once, in a fixed order.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(nested_u32);
visitor(nested_bool);
visitor(nested_vector);
visitor(nested_enum);
visitor(nested_str);
visitor(nested_f);
}
// Expects that each member of `n` equals the corresponding member of *this.
void CheckEqual(const Nested& n) const {
EXPECT_EQ(nested_u32, n.nested_u32);
EXPECT_EQ(nested_bool, n.nested_bool);
CheckVectorEqual(nested_vector, n.nested_vector);
EXPECT_EQ(nested_enum, n.nested_enum);
EXPECT_EQ(nested_str, n.nested_str);
EXPECT_EQ(nested_f, n.nested_f);
}
// Members and their default values.
uint32_t nested_u32; // set in ctor
bool nested_bool = true;
std::vector<uint32_t> nested_vector = {1, 2, 3};
Enum nested_enum = Enum::k1;
std::string nested_str = "nested";
float nested_f = 1.125f;
};
// Contains all supported types. Represents the "old" version of a struct;
// NewFields visits these same members first and then appends more.
struct OldFields : public IFields {
// Caption for this struct.
const char* Name() const override { return "OldFields"; }
// Visits every member exactly once, in a fixed order.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(old_str);
visitor(old_nested);
visitor(old1);
visitor(old_vec_str);
visitor(old_vec_nested);
visitor(old_f);
visitor(old_enum);
visitor(old_bool);
}
// Expects that the old_* members of `n` equal those of *this.
// Template allows comparing with NewFields.
template <typename Other>
void CheckEqual(const Other& n) const {
EXPECT_EQ(old_str, n.old_str);
old_nested.CheckEqual(n.old_nested);
EXPECT_EQ(old1, n.old1);
CheckVectorEqual(old_vec_str, n.old_vec_str);
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
EXPECT_EQ(old_f, n.old_f);
EXPECT_EQ(old_enum, n.old_enum);
EXPECT_EQ(old_bool, n.old_bool);
}
// Members and their default values.
std::string old_str = "old";
Nested old_nested = Nested(0);
uint32_t old1 = 1;
std::vector<std::string> old_vec_str = {"abc", "1234"};
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
float old_f = 1.125f;
Enum old_enum = Enum::k1;
bool old_bool = true;
}; // OldFields
// Simulates adding new fields of all types to an existing struct.
struct NewFields : public IFields {
// Caption for this struct.
const char* Name() const override { return "NewFields"; }
// Visits the old_* members first, in the same order as OldFields, followed by
// the new_* members.
void VisitFields(IFieldsVisitor& visitor) override {
visitor(old_str);
visitor(old_nested);
visitor(old1);
visitor(old_vec_str);
visitor(old_vec_nested);
visitor(old_f);
visitor(old_enum);
visitor(old_bool);
// Change order of field types relative to OldFields to ensure that works.
visitor(new_nested);
visitor(new_bool);
visitor(new_vec_nested);
visitor(new_f);
visitor(new_vec_str);
visitor(new_enum);
visitor(new2);
visitor(new_str);
}
// Expects that every member of `n` (old and new) equals that of *this.
void CheckEqual(const NewFields& n) const {
EXPECT_EQ(old_str, n.old_str);
old_nested.CheckEqual(n.old_nested);
EXPECT_EQ(old1, n.old1);
CheckVectorEqual(old_vec_str, n.old_vec_str);
CheckVectorEqual(old_vec_nested, n.old_vec_nested);
EXPECT_EQ(old_f, n.old_f);
EXPECT_EQ(old_enum, n.old_enum);
EXPECT_EQ(old_bool, n.old_bool);
new_nested.CheckEqual(n.new_nested);
EXPECT_EQ(new_bool, n.new_bool);
CheckVectorEqual(new_vec_nested, n.new_vec_nested);
EXPECT_EQ(new_f, n.new_f);
CheckVectorEqual(new_vec_str, n.new_vec_str);
EXPECT_EQ(new_enum, n.new_enum);
EXPECT_EQ(new2, n.new2);
EXPECT_EQ(new_str, n.new_str);
}
// Copied from OldFields to match the use case of adding new fields. If we
// write an OldFields member, that would change the layout due to its size.
std::string old_str = "old";
Nested old_nested = Nested(0);
uint32_t old1 = 1;
std::vector<std::string> old_vec_str = {"abc", "1234"};
std::vector<Nested> old_vec_nested = {Nested(1), Nested(4)};
float old_f = 1.125f;
Enum old_enum = Enum::k1;
bool old_bool = true;
// Members that exist only in the new version, with their default values.
Nested new_nested = Nested(999);
bool new_bool = false;
std::vector<Nested> new_vec_nested = {Nested(2), Nested(3)};
float new_f = -2.0f;
std::vector<std::string> new_vec_str = {"AB", std::string(), "56789"};
Enum new_enum = Enum::k3;
uint32_t new2 = 2;
std::string new_str = std::string(); // empty is allowed
}; // NewFields
// Returns a NewFields in which every member differs from its default.
NewFields ModifiedNewFields() {
  NewFields fields;

  // Members shared with OldFields.
  fields.old_str = "old2";
  fields.old_nested = Nested(5);
  fields.old1 = 11;
  fields.old_vec_str = {"abc2", "431", "ZZ"};
  fields.old_vec_nested = {Nested(9)};
  fields.old_f = -2.5f;
  fields.old_enum = Enum::k3;
  fields.old_bool = false;

  // Members that only exist in NewFields.
  fields.new_nested = Nested(55);
  fields.new_bool = true;
  fields.new_vec_nested = {Nested(3), Nested(33), Nested(333)};
  fields.new_f = 4.f;
  fields.new_vec_str = {"4321", "321", "21", "1"};
  fields.new_enum = Enum::k8;
  fields.new2 = 22;
  fields.new_str = "new and even longer";

  return fields;
}
using Span = hwy::Span<const uint32_t>;
using ReadResult = IFields::ReadResult;
// Expects that `result` describes a successful read which ended exactly at
// `storage_size`, with neither missing fields nor extra data.
void CheckConsumedAll(const ReadResult& result, size_t storage_size) {
  // Ensure we notice failure (pos == 0).
  EXPECT_NE(0, storage_size);

  const size_t resume_pos = result.pos;
  EXPECT_EQ(storage_size, resume_pos);

  EXPECT_EQ(0, result.missing_fields);
  EXPECT_EQ(0, result.extra_u32);
}
// Writing untouched fields and reading them back into the same object must
// leave every member at its default.
TEST(FieldsTest, TestNewMatchesDefaults) {
  NewFields fields;
  const std::vector<uint32_t> serialized = fields.Write();

  const Span span(serialized);
  const ReadResult result = fields.Read(span, 0);

  CheckConsumedAll(result, serialized.size());
  NewFields().CheckEqual(fields);
}
// Non-default values must survive Write followed by Read into a fresh object.
TEST(FieldsTest, TestRoundTrip) {
  NewFields original = ModifiedNewFields();
  const std::vector<uint32_t> serialized = original.Write();

  NewFields restored;
  const Span span(serialized);
  const ReadResult result = restored.Read(span, 0);

  CheckConsumedAll(result, serialized.size());
  original.CheckEqual(restored);
}
// Write() must return empty storage when a float is infinite or NaN.
TEST(FieldsTest, TestInvalidFloat) {
  using Limits = std::numeric_limits<float>;
  NewFields fields;

  fields.new_f = Limits::infinity();
  EXPECT_TRUE(fields.Write().empty());

  fields.new_f = Limits::quiet_NaN();
  EXPECT_TRUE(fields.Write().empty());
}
// Write() must return empty storage for each kind of invalid string.
TEST(FieldsTest, TestInvalidString) {
  NewFields fields;
  std::string& str = fields.new_str;

  // Four zero bytes
  str.assign(4, '\0');
  EXPECT_TRUE(fields.Write().empty());

  // Too long
  str.assign(257, 'a');
  EXPECT_TRUE(fields.Write().empty());

  // First byte not ASCII
  str.assign("123");
  str[0] = 128;
  EXPECT_TRUE(fields.Write().empty());

  // Upper byte in later u32 not ASCII
  str.assign("ABCDEFGH");
  str[7] = 255;
  EXPECT_TRUE(fields.Write().empty());
}
// Write two structs to the same storage, then read them back one after the
// other, resuming the second read where the first one ended.
TEST(FieldsTest, TestMultipleWrite) {
  const NewFields modified = ModifiedNewFields();
  std::vector<uint32_t> storage = modified.Write();
  const size_t modified_size = storage.size();

  const NewFields defaults;
  // Check the result: a failed append would otherwise only show up later as a
  // confusing read mismatch.
  EXPECT_TRUE(defaults.AppendTo(storage));

  // Start with defaults to ensure Read retrieves the modified values.
  NewFields modified_copy;
  const ReadResult result1 = modified_copy.Read(Span(storage), 0);
  CheckConsumedAll(result1, modified_size);
  modified.CheckEqual(modified_copy);

  // Start with modified values to ensure Read retrieves the defaults.
  NewFields defaults_copy = modified;
  const ReadResult result2 = defaults_copy.Read(Span(storage), result1.pos);
  CheckConsumedAll(result2, storage.size());
  defaults.CheckEqual(defaults_copy);
}
// New code reading old data: the old members take the stored values, and the
// result reports missing fields but no extra data.
TEST(FieldsTest, TestNewCodeOldData) {
  OldFields written;
  const std::vector<uint32_t> serialized = written.Write();

  // Start with modified old values to ensure old defaults overwrite them.
  NewFields reader = ModifiedNewFields();
  const Span span(serialized);
  const ReadResult result = reader.Read(span, 0);
  MaybePrint(reader);

  EXPECT_NE(0, result.pos);  // did not fail
  EXPECT_NE(0, result.missing_fields);
  EXPECT_EQ(0, result.extra_u32);
  written.CheckEqual(reader);  // old fields are the same in both
}
// New code reading old data: members absent from the data must keep their
// defaults.
TEST(FieldsTest, TestNewCodeOldDataNewUnchanged) {
  OldFields written;
  const std::vector<uint32_t> serialized = written.Write();

  NewFields reader;
  const Span span(serialized);
  const ReadResult result = reader.Read(span, 0);
  MaybePrint(reader);

  EXPECT_NE(0, result.pos);  // did not fail
  EXPECT_NE(0, result.missing_fields);
  EXPECT_EQ(0, result.extra_u32);
  NewFields().CheckEqual(reader);  // new fields match their defaults
}
// Old code reading new data: the old members are read, and the unread
// remainder is reported as extra u32 rather than as an error.
TEST(FieldsTest, TestOldCodeNewData) {
  NewFields written;
  const std::vector<uint32_t> serialized = written.Write();

  OldFields reader;
  const Span span(serialized);
  const ReadResult result = reader.Read(span, 0);
  MaybePrint(reader);

  EXPECT_NE(0, result.pos);  // did not fail
  EXPECT_EQ(0, result.missing_fields);
  EXPECT_NE(0, result.extra_u32);
  // Position plus the skipped remainder accounts for the whole storage.
  EXPECT_EQ(serialized.size(), result.pos + result.extra_u32);
  reader.CheckEqual(written);  // old fields are the same in both
  // (Can't check new fields because we only read OldFields)
}
} // namespace
} // namespace gcpp
// Macro from hwy/tests/hwy_gtest.h, invoked at file scope outside the
// namespaces. Presumably supplies the test entry point - TODO confirm.
HWY_TEST_MAIN();