diff --git a/CMakeLists.txt b/CMakeLists.txt index f70def7..1990481 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,8 @@ set(SOURCES compression/compress.cc compression/compress.h compression/compress-inl.h + compression/fields.cc + compression/fields.h compression/io_win.cc compression/io.cc compression/io.h @@ -150,6 +152,7 @@ set(GEMMA_TEST_FILES compression/blob_store_test.cc compression/compress_test.cc compression/distortion_test.cc + compression/fields_test.cc compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index b6508cd..832f635 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -37,6 +37,26 @@ cc_library( ] + FILE_DEPS, ) +cc_library( + name = "fields", + srcs = ["fields.cc"], + hdrs = ["fields.h"], + deps = [ + "@highway//:hwy", + ], +) + +cc_test( + name = "fields_test", + srcs = ["fields_test.cc"], + deps = [ + ":fields", + "@googletest//:gtest_main", # buildcleaner: keep + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) + cc_library( name = "blob_store", srcs = ["blob_store.cc"], diff --git a/compression/fields.cc b/compression/fields.cc new file mode 100644 index 0000000..de90ec8 --- /dev/null +++ b/compression/fields.cc @@ -0,0 +1,320 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "compression/fields.h"
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+
+namespace gcpp {
+
+IFieldsVisitor::~IFieldsVisitor() = default;
+
+// Prints the printf-style message to stderr and latches `any_invalid_`.
+void IFieldsVisitor::NotifyInvalid(const char* fmt, ...) {
+  va_list args;
+  va_start(args, fmt);
+  vfprintf(stderr, fmt, args);
+  va_end(args);
+
+  any_invalid_ = true;
+}
+
+// Common base of all visitors in this file: forwards nested IFields to their
+// VisitFields and provides the validity checks shared by reading and writing.
+class VisitorBase : public IFieldsVisitor {
+ public:
+  VisitorBase() = default;
+  ~VisitorBase() override = default;
+
+  // This is the only call site of IFields::VisitFields.
+  void operator()(IFields& fields) override { fields.VisitFields(*this); }
+
+ protected:  // Functions shared between ReadVisitor and WriteVisitor:
+  // Flags infinities and NaN as invalid.
+  void CheckF32(float value) {
+    if (HWY_UNLIKELY(hwy::ScalarIsInf(value) || hwy::ScalarIsNaN(value))) {
+      NotifyInvalid("Invalid float %g\n", value);
+    }
+  }
+
+  // Return bool to avoid having to check AnyInvalid() after calling.
+  bool CheckStringLength(uint32_t num_u32) {
+    // Disallow long strings for safety, and to prevent them being used for
+    // arbitrary data (we also require them to be ASCII).
+    if (HWY_UNLIKELY(num_u32 > 64)) {
+      NotifyInvalid("String num_u32=%u too large\n", num_u32);
+      return false;
+    }
+    return true;
+  }
+
+  // Returns false (after NotifyInvalid) if the `i`-th of `num_u32` string
+  // words is entirely zero or has the top bit of any byte set.
+  bool CheckStringU32(uint32_t u32, uint32_t i, uint32_t num_u32) {
+    // Although strings are zero-padded to u32, an entire u32 should not be
+    // zero, and upper bits should not be set (ASCII-only).
+    if (HWY_UNLIKELY(u32 == 0 || (u32 & 0x80808080))) {
+      NotifyInvalid("Invalid characters %x at %u of %u\n", u32, i, num_u32);
+      return false;
+    }
+    return true;
+  }
+};
+
+// Prints each visited field to stderr, indenting the members of nested IFields.
+class PrintVisitor : public VisitorBase {
+ public:
+  void operator()(uint32_t& value) override {
+    fprintf(stderr, "%sU32 %u\n", indent_.c_str(), value);
+  }
+
+  void operator()(float& value) override {
+    fprintf(stderr, "%sF32 %f\n", indent_.c_str(), value);
+  }
+
+  void operator()(std::string& value) override {
+    fprintf(stderr, "%sStr %s\n", indent_.c_str(), value.c_str());
+  }
+
+  void operator()(IFields& fields) override {
+    fprintf(stderr, "%s%s\n", indent_.c_str(), fields.Name());
+    // Remember the current depth so that exactly what is appended below is
+    // removed again. Subtracting a fixed count can underflow if it ever
+    // disagrees with the length of the appended literal.
+    const size_t prev_size = indent_.size();
+    indent_ += " ";
+
+    VisitorBase::operator()(fields);
+
+    HWY_ASSERT(indent_.size() >= prev_size);
+    indent_.resize(prev_size);
+  }
+
+ private:
+  std::string indent_;
+};
+
+// Decodes fields from `span_`, starting at the `pos` passed to the constructor.
+// Bounds are tracked per nesting level in `end_`; see SkipField.
+class ReadVisitor : public VisitorBase {
+ public:
+  ReadVisitor(const hwy::Span& span, size_t pos)
+      : span_(span), result_(pos) {}
+  ~ReadVisitor() {
+    HWY_ASSERT(end_.empty());  // Bug if push/pop are not balanced.
+  }
+
+  // All data is read through this overload.
+  void operator()(uint32_t& value) override {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    value = span_[result_.pos++];
+  }
+
+  void operator()(float& value) override {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    uint32_t u32 = hwy::BitCastScalar(value);
+    operator()(u32);
+    value = hwy::BitCastScalar(u32);
+    CheckF32(value);
+  }
+
+  void operator()(std::string& value) override {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    uint32_t num_u32;     // not including itself because this..
+    operator()(num_u32);  // increments result_.pos for the num_u32 field
+    if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
+
+    // Ensure we have that much data.
+    if (HWY_UNLIKELY(result_.pos + num_u32 > end_.back())) {
+      // Report the bound that was actually checked, not the span size.
+      NotifyInvalid("Invalid string: pos %zu + num_u32 %u > end %zu\n",
+                    result_.pos, num_u32, end_.back());
+      return;
+    }
+
+    constexpr size_t k4 = sizeof(uint32_t);
+    value.resize(num_u32 * k4);
+    for (uint32_t i = 0; i < num_u32; ++i) {
+      uint32_t u32;
+      operator()(u32);
+      (void)CheckStringU32(u32, i, num_u32);
+      hwy::CopyBytes(&u32, value.data() + i * k4, k4);
+    }
+
+    // Trim 0..3 trailing nulls.
+    const size_t pos = value.find_last_not_of('\0');
+    if (pos != std::string::npos) {
+      value.resize(pos + 1);
+    }
+  }
+
+  void operator()(IFields& fields) override {
+    // A nested IFields must end within its parent, so the enclosing end is the
+    // upper bound; only the outermost IFields is bounded by the whole span.
+    // Bounding nested fields by span_.size() would allow a nested field that
+    // is absent from old data to be read from whatever follows the parent.
+    const size_t limit = end_.empty() ? span_.size() : end_.back();
+    // Our SkipField requires end_ to be set before reading num_u32, which
+    // determines the actual end, so start with the upper bound.
+    end_.push_back(limit);
+
+    if (HWY_UNLIKELY(SkipField())) {
+      end_.pop_back();  // undo `push_back` to keep the stack balanced
+      return;
+    }
+
+    uint32_t num_u32 = 0;  // not including itself because this..
+    operator()(num_u32);   // increments result_.pos for the num_u32 field
+
+    // Ensure we have that much data and set end_. SkipField established
+    // pos < limit before the increment, so the subtraction cannot wrap.
+    if (HWY_UNLIKELY(num_u32 > limit - result_.pos)) {
+      NotifyInvalid("Invalid IFields: pos %zu + num_u32 %u > size %zu\n",
+                    result_.pos, num_u32, limit);
+      end_.pop_back();  // balance the stack on this error path as well
+      return;
+    }
+    end_.back() = result_.pos + num_u32;
+
+    VisitorBase::operator()(fields);
+
+    HWY_ASSERT(!end_.empty() && result_.pos <= end_.back());
+    // Count extra, which indicates old code and new data.
+    result_.extra_u32 += end_.back() - result_.pos;
+    end_.pop_back();
+  }
+
+  // Override because ReadVisitor also does bounds checking.
+  bool SkipField() override {
+    // If invalid, all bets are off and we don't count missing fields.
+    if (HWY_UNLIKELY(AnyInvalid())) return true;
+
+    // Reaching the end of the stored size, or the span, is not invalid -
+    // it happens when we read old data with new code.
+    if (HWY_UNLIKELY(result_.pos >= end_.back())) {
+      result_.missing_fields++;
+      return true;
+    }
+
+    return false;
+  }
+
+  // Override so that operator()(std::vector&) resizes the vector.
+  bool IsReading() const override { return true; }
+
+  // Returns the accumulated result; `pos` is zeroed if anything was invalid.
+  IFields::ReadResult Result() {
+    if (HWY_UNLIKELY(AnyInvalid())) result_.pos = 0;
+    return result_;
+  }
+
+ private:
+  const hwy::Span span_;
+  IFields::ReadResult result_;
+  // Stack of end positions of nested IFields. Updated in operator()(IFields&),
+  // but read in SkipField.
+  std::vector end_;
+};
+
+// Appends the encoded fields to `storage_`, one or more u32 per field.
+class WriteVisitor : public VisitorBase {
+ public:
+  WriteVisitor(std::vector& storage) : storage_(storage) {}
+
+  // Note: while writing, only string lengths/characters can trigger AnyInvalid,
+  // so we don't have to check SkipField.
+
+  void operator()(uint32_t& value) override { storage_.push_back(value); }
+
+  void operator()(float& value) override {
+    storage_.push_back(hwy::BitCastScalar(value));
+    CheckF32(value);
+  }
+
+  void operator()(std::string& value) override {
+    constexpr size_t k4 = sizeof(uint32_t);
+
+    // Write length.
+    uint32_t num_u32 = hwy::DivCeil(value.size(), k4);
+    if (HWY_UNLIKELY(!CheckStringLength(num_u32))) return;
+    operator()(num_u32);  // always valid
+
+    // Copy whole uint32_t.
+    const size_t num_whole_u32 = value.size() / k4;
+    for (uint32_t i = 0; i < num_whole_u32; ++i) {
+      uint32_t u32 = 0;
+      hwy::CopyBytes(value.data() + i * k4, &u32, k4);
+      if (HWY_UNLIKELY(!CheckStringU32(u32, i, num_u32))) return;
+      storage_.push_back(u32);
+    }
+
+    // Read remaining bytes into least-significant bits of u32.
+    const size_t remainder = value.size() - num_whole_u32 * k4;
+    if (remainder != 0) {
+      HWY_DASSERT(remainder < k4);
+      uint32_t u32 = 0;
+      for (size_t i = 0; i < remainder; ++i) {
+        const char c = value[num_whole_u32 * k4 + i];
+        const uint32_t next = static_cast(static_cast(c));
+        u32 += next << (i * 8);
+      }
+      if (HWY_UNLIKELY(!CheckStringU32(u32, num_whole_u32, num_u32))) return;
+      storage_.push_back(u32);
+    }
+  }
+
+  void operator()(IFields& fields) override {
+    const size_t pos_before_size = storage_.size();
+    storage_.push_back(0);  // placeholder, updated below
+
+    VisitorBase::operator()(fields);
+
+    HWY_ASSERT(storage_[pos_before_size] == 0);
+    // Number of u32 written, including the one storing that number.
+    const uint32_t num_u32 = storage_.size() - pos_before_size;
+    HWY_ASSERT(num_u32 != 0);  // at least one due to push_back above
+    // Store the payload size, not including the num_u32 field itself, because
+    // that is more convenient for ReadVisitor.
+    storage_[pos_before_size] = num_u32 - 1;
+  }
+
+ private:
+  std::vector& storage_;
+};
+
+IFields::~IFields() = default;
+
+void IFields::Print() const {
+  PrintVisitor visitor;
+  // VisitFields is non-const. It is safe to cast because PrintVisitor does not
+  // modify the fields.
+  visitor(*const_cast(this));
+}
+
+IFields::ReadResult IFields::Read(const hwy::Span& span,
+                                  size_t pos) {
+  ReadVisitor visitor(span, pos);
+  visitor(*this);
+  return visitor.Result();
+}
+
+bool IFields::AppendTo(std::vector& storage) const {
+  // VisitFields is non-const. It is safe to cast because WriteVisitor does not
+  // modify the fields.
+  IFields& fields = *const_cast(this);
+
+  // Reduce allocations, but not in debug builds so we notice any iterator
+  // invalidation bugs.
+  if constexpr (!HWY_IS_DEBUG_BUILD) {
+    storage.reserve(storage.size() + 256);
+  }
+
+  WriteVisitor visitor(storage);
+  visitor(fields);
+  return !visitor.AnyInvalid();
+}
+
+}  // namespace gcpp
diff --git a/compression/fields.h b/compression/fields.h
new file mode 100644
index 0000000..2ee409a
--- /dev/null
+++ b/compression/fields.h
@@ -0,0 +1,200 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
+#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
+
+// Simple serialization/deserialization for user-defined classes, inspired by
+// BSD-licensed code Copyright (c) the JPEG XL Project Authors:
+// https://github.com/libjxl/libjxl, lib/jxl/fields.h.
+
+// IWYU pragma: begin_exports
+#include
+#include
+
+#include
+#include
+
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+// IWYU pragma: end_exports
+
+namespace gcpp {
+
+// Design goals:
+// - self-contained to simplify installing/building, no separate compiler (rules
+//   out Protocol Buffers, FlatBuffers, Cap'n Proto, Apache Thrift).
+// - simplicity: small codebase without JIT (rules out Apache Fury and bitsery).
+// - old code can read new data, and new code can read old data (rules out yas
+//   and msgpack). This avoids rewriting weights when we add a new field.
+// - no user-specified versions (rules out cereal) nor field names (rules out
+//   JSON and GGUF).
+//   These are error-prone; users should just be able to append
+//   new fields.
+//
+// Non-goals:
+// - anything better than reasonable encoding size and decode speed: we only
+//   anticipate ~KiB of data, alongside ~GiB of separately compressed weights.
+// - features such as maps, interfaces, random access, and optional/deleted
+//   fields: not required for the intended use case of `ModelConfig`.
+// - support any other languages than C++ and Python (for the exporter).
+
+class IFields;  // breaks circular dependency
+
+// Visitors are internal-only, but their base class is visible to user code
+// because their `IFields::VisitFields` calls `visitor.operator()`.
+//
+// Supported field types `T`: `uint32_t`, `float`, `std::string`, classes
+// derived from `IFields`, `bool`, `enum`, `std::vector`.
+class IFieldsVisitor {
+ public:
+  virtual ~IFieldsVisitor();
+
+  // Indicates whether NotifyInvalid was called for any field. Once set, this is
+  // sticky for all IFields visited by this visitor.
+  bool AnyInvalid() const { return any_invalid_; }
+
+  // None of these fail directly, but they call NotifyInvalid() if any value
+  // is out of range. A single generic/overloaded function is required to
+  // support `std::vector`.
+  virtual void operator()(uint32_t& value) = 0;
+  virtual void operator()(float& value) = 0;
+  virtual void operator()(std::string& value) = 0;
+  virtual void operator()(IFields& fields) = 0;  // recurse into nested fields
+
+  // bool and enum fields are actually stored as uint32_t.
+  void operator()(bool& value) {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    uint32_t u32 = value ? 1 : 0;
+    operator()(u32);
+    // Only 0 and 1 are valid encodings; `value` is left untouched otherwise.
+    if (HWY_UNLIKELY(u32 > 1)) {
+      return NotifyInvalid("Invalid bool %u\n", u32);
+    }
+    value = (u32 == 1);
+  }
+
+  // Enums are validated via the user-provided `EnumValid` overload; `value` is
+  // only updated if the (possibly just read) u32 passes that check.
+  template >* = nullptr>
+  void operator()(EnumT& value) {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    uint32_t u32 = static_cast(value);
+    operator()(u32);
+    if (HWY_UNLIKELY(!EnumValid(static_cast(u32)))) {
+      // The format string consumes one argument; omitting it is undefined
+      // behavior in vfprintf.
+      return NotifyInvalid("Invalid enum %u\n", u32);
+    }
+    value = static_cast(u32);
+  }
+
+  // Vectors are stored as their element count followed by each element.
+  template
+  void operator()(std::vector& value) {
+    if (HWY_UNLIKELY(SkipField())) return;
+
+    uint32_t num = static_cast(value.size());
+    operator()(num);
+    if (HWY_UNLIKELY(num > 64 * 1024)) {
+      return NotifyInvalid("Vector too long %u\n", num);
+    }
+
+    // Only a reading visitor changes the size; other visitors iterate over the
+    // existing elements.
+    if (IsReading()) {
+      value.resize(num);
+    }
+    for (size_t i = 0; i < value.size(); ++i) {
+      operator()(value[i]);
+    }
+  }
+
+ protected:
+  // Prints a message and causes subsequent AnyInvalid() to return true.
+  void NotifyInvalid(const char* fmt, ...);
+
+  // Must check this before modifying any field, and if it returns true,
+  // avoid doing so. This is important for strings and vectors in the
+  // "new code, old data" test: resizing them may destroy their contents.
+  virtual bool SkipField() { return AnyInvalid(); }
+  // For operator()(std::vector&).
+  virtual bool IsReading() const { return false; }
+
+ private:
+  bool any_invalid_ = false;
+};
+
+// Abstract base class for user-defined serializable classes, which are
+// forward- and backward compatible collection of fields (members). This means
+// old code can safely read new data, and new code can still handle old data.
+//
+// Fields are written in the unchanging order established by the user-defined
+// `VisitFields`; any new fields must be visited after all existing fields in
+// the same `IFields`. We encode each into `uint32_t` storage for simplicity.
+//
+// HOWTO:
+// - basic usage: define a struct with member variables ("fields") and their
+//   initializers, e.g. `uint32_t field = 0;`.
+//   Then define a
+//   `void VisitFields(IFieldsVisitor& v)` member function that calls
+//   `v(field);` etc. for each field, and a `const char* Name()` function used
+//   as a caption when printing.
+//
+// - enum fields: define `enum class EnumT` and `bool EnumValid(EnumT)`, then
+//   call `v(field);` as usual. Note that `EnumT` is not extendable insofar as
+//   `EnumValid` returns false for values beyond the initially known ones. You
+//   can add placeholders, which requires user code to know how to handle them,
+//   or later add new fields including enums to override the first enum.
+struct IFields {
+  virtual ~IFields();
+
+  // User-defined caption used during Print().
+  virtual const char* Name() const = 0;
+
+  // User-defined, called by IFieldsVisitor::operator()(IFields&).
+  virtual void VisitFields(IFieldsVisitor& visitor) = 0;
+
+  // Prints name and fields to stderr.
+  void Print() const;
+
+  struct ReadResult {
+    ReadResult(size_t pos) : pos(pos), missing_fields(0), extra_u32(0) {}
+
+    // Where to resume reading in the next Read() call, or 0 if there was an
+    // unrecoverable error: any field has an invalid value, or the span is
+    // shorter than the data says it should be. If so, do not use the fields nor
+    // continue reading.
+    size_t pos;
+    // From the perspective of VisitFields, how many more fields would have
+    // been read beyond the stored size. If non-zero, the data is older than
+    // the code, but valid, and extra_u32 should be zero.
+    uint32_t missing_fields;
+    // How many extra u32 are in the stored size, vs. what we actually read as
+    // requested by VisitFields. If non-zero, the data is newer than the code,
+    // but valid, and missing_fields should be zero.
+    uint32_t extra_u32;
+  };
+
+  // Reads fields starting at `span[pos]`.
+  ReadResult Read(const hwy::Span& span, size_t pos);
+
+  // Returns false if there was an unrecoverable error, typically because a
+  // field has an invalid value. If so, `storage` is undefined.
+  bool AppendTo(std::vector& storage) const;
+
+  // Convenience wrapper for AppendTo when we only write once. Returns an empty
+  // vector if AppendTo failed.
+  std::vector Write() const {
+    std::vector storage;
+    if (!AppendTo(storage)) storage.clear();
+    return storage;
+  }
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_FIELDS_H_
diff --git a/compression/fields_test.cc b/compression/fields_test.cc
new file mode 100644
index 0000000..d6c1c51
--- /dev/null
+++ b/compression/fields_test.cc
@@ -0,0 +1,354 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "compression/fields.h"
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include "hwy/tests/hwy_gtest.h"
+
+namespace gcpp {
+namespace {
+
+#if !HWY_TEST_STANDALONE
+class FieldsTest : public testing::Test {};
+#endif
+
+// Prints `fields` to stderr, but only in debug builds.
+void MaybePrint(const IFields& fields) {
+  if (HWY_IS_DEBUG_BUILD) {
+    fields.Print();
+  }
+}
+
+// Expects equal sizes and element-wise equality, using the elements'
+// `CheckEqual` when they are IFields.
+template
+void CheckVectorEqual(const std::vector& a, const std::vector& b) {
+  EXPECT_EQ(a.size(), b.size());
+  // The mismatch was already reported; indexing `b` up to `a.size()` would
+  // read out of bounds if `b` is shorter.
+  if (a.size() != b.size()) return;
+  for (size_t i = 0; i < a.size(); ++i) {
+    if constexpr (std::is_base_of_v) {
+      a[i].CheckEqual(b[i]);
+    } else {
+      EXPECT_EQ(a[i], b[i]);
+    }
+  }
+}
+
+// Deliberately sparse values so that `EnumValid` is not a simple range check.
+enum class Enum : uint32_t {
+  k1 = 1,
+  k3 = 3,
+  k8 = 8,
+};
+HWY_MAYBE_UNUSED bool EnumValid(Enum e) {
+  return e == Enum::k1 || e == Enum::k3 || e == Enum::k8;
+}
+
+// Contains all supported types except IFields and std::vector.
+struct Nested : public IFields {
+  Nested() : nested_u32(0) {}  // for std::vector
+  explicit Nested(uint32_t u32) : nested_u32(u32) {}
+
+  const char* Name() const override { return "Nested"; }
+  // The visiting order defines the serialized order of the fields.
+  void VisitFields(IFieldsVisitor& visitor) override {
+    visitor(nested_u32);
+    visitor(nested_bool);
+    visitor(nested_vector);
+    visitor(nested_enum);
+    visitor(nested_str);
+    visitor(nested_f);
+  }
+
+  // Expects every field of `n` to equal the corresponding field of this.
+  void CheckEqual(const Nested& n) const {
+    EXPECT_EQ(nested_u32, n.nested_u32);
+    EXPECT_EQ(nested_bool, n.nested_bool);
+    CheckVectorEqual(nested_vector, n.nested_vector);
+    EXPECT_EQ(nested_enum, n.nested_enum);
+    EXPECT_EQ(nested_str, n.nested_str);
+    EXPECT_EQ(nested_f, n.nested_f);
+  }
+
+  uint32_t nested_u32;  // set in ctor
+  bool nested_bool = true;
+  std::vector nested_vector = {1, 2, 3};
+  Enum nested_enum = Enum::k1;
+  std::string nested_str = "nested";
+  float nested_f = 1.125f;
+};
+
+// Contains all supported types.
+struct OldFields : public IFields {
+  const char* Name() const override { return "OldFields"; }
+  // NewFields::VisitFields starts with these same eight fields, in this order.
+  void VisitFields(IFieldsVisitor& visitor) override {
+    visitor(old_str);
+    visitor(old_nested);
+    visitor(old1);
+    visitor(old_vec_str);
+    visitor(old_vec_nested);
+    visitor(old_f);
+    visitor(old_enum);
+    visitor(old_bool);
+  }
+
+  // Template allows comparing with NewFields.
+  template
+  void CheckEqual(const Other& n) const {
+    EXPECT_EQ(old_str, n.old_str);
+    old_nested.CheckEqual(n.old_nested);
+    EXPECT_EQ(old1, n.old1);
+    CheckVectorEqual(old_vec_str, n.old_vec_str);
+    CheckVectorEqual(old_vec_nested, n.old_vec_nested);
+    EXPECT_EQ(old_f, n.old_f);
+    EXPECT_EQ(old_enum, n.old_enum);
+    EXPECT_EQ(old_bool, n.old_bool);
+  }
+
+  std::string old_str = "old";
+  Nested old_nested = Nested(0);
+  uint32_t old1 = 1;
+  std::vector old_vec_str = {"abc", "1234"};
+  std::vector old_vec_nested = {Nested(1), Nested(4)};
+  float old_f = 1.125f;
+  Enum old_enum = Enum::k1;
+  bool old_bool = true;
+};  // OldFields
+
+// Simulates adding new fields of all types to an existing struct.
+struct NewFields : public IFields {
+  const char* Name() const override { return "NewFields"; }
+  void VisitFields(IFieldsVisitor& visitor) override {
+    // Same fields and order as OldFields::VisitFields.
+    visitor(old_str);
+    visitor(old_nested);
+    visitor(old1);
+    visitor(old_vec_str);
+    visitor(old_vec_nested);
+    visitor(old_f);
+    visitor(old_enum);
+    visitor(old_bool);
+
+    // Change order of field types relative to OldFields to ensure that works.
+    visitor(new_nested);
+    visitor(new_bool);
+    visitor(new_vec_nested);
+    visitor(new_f);
+    visitor(new_vec_str);
+    visitor(new_enum);
+    visitor(new2);
+    visitor(new_str);
+  }
+
+  // Expects all old and new fields of `n` to equal those of this.
+  void CheckEqual(const NewFields& n) const {
+    EXPECT_EQ(old_str, n.old_str);
+    old_nested.CheckEqual(n.old_nested);
+    EXPECT_EQ(old1, n.old1);
+    CheckVectorEqual(old_vec_str, n.old_vec_str);
+    CheckVectorEqual(old_vec_nested, n.old_vec_nested);
+    EXPECT_EQ(old_f, n.old_f);
+    EXPECT_EQ(old_enum, n.old_enum);
+    EXPECT_EQ(old_bool, n.old_bool);
+
+    new_nested.CheckEqual(n.new_nested);
+    EXPECT_EQ(new_bool, n.new_bool);
+    CheckVectorEqual(new_vec_nested, n.new_vec_nested);
+    EXPECT_EQ(new_f, n.new_f);
+    CheckVectorEqual(new_vec_str, n.new_vec_str);
+    EXPECT_EQ(new_enum, n.new_enum);
+    EXPECT_EQ(new2, n.new2);
+    EXPECT_EQ(new_str, n.new_str);
+  }
+
+  // Copied from OldFields to match the use case of adding new fields. If we
+  // write an OldFields member, that would change the layout due to its size.
+  std::string old_str = "old";
+  Nested old_nested = Nested(0);
+  uint32_t old1 = 1;
+  std::vector old_vec_str = {"abc", "1234"};
+  std::vector old_vec_nested = {Nested(1), Nested(4)};
+  float old_f = 1.125f;
+  Enum old_enum = Enum::k1;
+  bool old_bool = true;
+
+  Nested new_nested = Nested(999);
+  bool new_bool = false;
+  std::vector new_vec_nested = {Nested(2), Nested(3)};
+  float new_f = -2.0f;
+  std::vector new_vec_str = {"AB", std::string(), "56789"};
+  Enum new_enum = Enum::k3;
+  uint32_t new2 = 2;
+  std::string new_str = std::string();  // empty is allowed
+};  // NewFields
+
+// Changes all fields to non-default values.
+NewFields ModifiedNewFields() {
+  NewFields n;
+  n.old_str = "old2";
+  n.old_nested = Nested(5);
+  n.old1 = 11;
+  n.old_vec_str = {"abc2", "431", "ZZ"};
+  n.old_vec_nested = {Nested(9)};
+  n.old_f = -2.5f;
+  n.old_enum = Enum::k3;
+  n.old_bool = false;
+
+  n.new_nested = Nested(55);
+  n.new_bool = true;
+  n.new_vec_nested = {Nested(3), Nested(33), Nested(333)};
+  n.new_f = 4.f;
+  n.new_vec_str = {"4321", "321", "21", "1"};
+  n.new_enum = Enum::k8;
+  n.new2 = 22;
+  n.new_str = "new and even longer";
+
+  return n;
+}
+
+using Span = hwy::Span;
+
+using ReadResult = IFields::ReadResult;
+// Expects that reading ended exactly at `storage_size` with neither missing
+// fields nor extra u32.
+void CheckConsumedAll(const ReadResult& result, size_t storage_size) {
+  EXPECT_NE(0, storage_size);  // Ensure we notice failure (pos == 0).
+  EXPECT_EQ(storage_size, result.pos);
+  EXPECT_EQ(0, result.missing_fields);
+  EXPECT_EQ(0, result.extra_u32);
+}
+
+// If we do not change any fields, Write+Read returns the defaults.
+TEST(FieldsTest, TestNewMatchesDefaults) {
+  NewFields new_fields;
+  const std::vector storage = new_fields.Write();
+
+  const ReadResult result = new_fields.Read(Span(storage), 0);
+  CheckConsumedAll(result, storage.size());
+
+  NewFields().CheckEqual(new_fields);
+}
+
+// Change fields from default and check that Write+Read returns them again.
+TEST(FieldsTest, TestRoundTrip) {
+  NewFields new_fields = ModifiedNewFields();
+  const std::vector storage = new_fields.Write();
+
+  // Read into a default-constructed object so that matching values must have
+  // come from `storage`.
+  NewFields copy;
+  const ReadResult result = copy.Read(Span(storage), 0);
+  CheckConsumedAll(result, storage.size());
+
+  new_fields.CheckEqual(copy);
+}
+
+// Refuse to write invalid floats.
+TEST(FieldsTest, TestInvalidFloat) {
+  NewFields new_fields;
+  // Write() returns an empty vector when AppendTo reports failure.
+  new_fields.new_f = std::numeric_limits::infinity();
+  EXPECT_TRUE(new_fields.Write().empty());
+
+  new_fields.new_f = std::numeric_limits::quiet_NaN();
+  EXPECT_TRUE(new_fields.Write().empty());
+}
+
+// Refuse to write invalid strings.
+TEST(FieldsTest, TestInvalidString) {
+  NewFields new_fields;
+  // Four zero bytes
+  new_fields.new_str.assign(4, '\0');
+  EXPECT_TRUE(new_fields.Write().empty());
+
+  // Too long
+  new_fields.new_str.assign(257, 'a');
+  EXPECT_TRUE(new_fields.Write().empty());
+
+  // First byte not ASCII
+  new_fields.new_str.assign("123");
+  new_fields.new_str[0] = 128;
+  EXPECT_TRUE(new_fields.Write().empty());
+
+  // Upper byte in later u32 not ASCII
+  new_fields.new_str.assign("ABCDEFGH");
+  new_fields.new_str[7] = 255;
+  EXPECT_TRUE(new_fields.Write().empty());
+}
+
+// Write two structs to the same storage.
+TEST(FieldsTest, TestMultipleWrite) {
+  const NewFields modified = ModifiedNewFields();
+  std::vector storage = modified.Write();
+  const size_t modified_size = storage.size();
+  const NewFields defaults;
+  // `storage` is undefined if AppendTo fails, so do not ignore its result.
+  EXPECT_TRUE(defaults.AppendTo(storage));
+
+  // Start with defaults to ensure Read retrieves the modified values.
+  NewFields modified_copy;
+  const ReadResult result1 = modified_copy.Read(Span(storage), 0);
+  CheckConsumedAll(result1, modified_size);
+  modified.CheckEqual(modified_copy);
+
+  // Start with modified values to ensure Read retrieves the defaults.
+  NewFields defaults_copy = modified;
+  const ReadResult result2 = defaults_copy.Read(Span(storage), result1.pos);
+  CheckConsumedAll(result2, storage.size());
+  defaults.CheckEqual(defaults_copy);
+}
+
+// Write old defaults, read old using new code.
+TEST(FieldsTest, TestNewCodeOldData) {
+  OldFields old_fields;
+  const std::vector storage = old_fields.Write();
+
+  // Start with modified old values to ensure old defaults overwrite them.
+  NewFields new_fields = ModifiedNewFields();
+  const ReadResult result = new_fields.Read(Span(storage), 0);
+  MaybePrint(new_fields);
+  EXPECT_NE(0, result.pos);  // did not fail
+  EXPECT_NE(0, result.missing_fields);
+  EXPECT_EQ(0, result.extra_u32);
+  old_fields.CheckEqual(new_fields);  // old fields are the same in both
+}
+
+// Write old defaults, ensure new defaults remain unchanged.
+TEST(FieldsTest, TestNewCodeOldDataNewUnchanged) {
+  OldFields old_fields;
+  const std::vector storage = old_fields.Write();
+
+  // Start from defaults; the stored data only contains the old fields.
+  NewFields new_fields;
+  const ReadResult result = new_fields.Read(Span(storage), 0);
+  MaybePrint(new_fields);
+  EXPECT_NE(0, result.pos);  // did not fail
+  EXPECT_NE(0, result.missing_fields);
+  EXPECT_EQ(0, result.extra_u32);
+  NewFields().CheckEqual(new_fields);  // new fields match their defaults
+}
+
+// Write new defaults, read using old code.
+TEST(FieldsTest, TestOldCodeNewData) {
+  NewFields new_fields;
+  const std::vector storage = new_fields.Write();
+
+  OldFields old_fields;
+  const ReadResult result = old_fields.Read(Span(storage), 0);
+  MaybePrint(old_fields);
+  EXPECT_NE(0, result.pos);  // did not fail
+  EXPECT_EQ(0, result.missing_fields);
+  EXPECT_NE(0, result.extra_u32);
+  // The u32 that were stored but not read account for the rest of storage.
+  EXPECT_EQ(storage.size(), result.pos + result.extra_u32);
+
+  old_fields.CheckEqual(new_fields);  // old fields are the same in both
+  // (Can't check new fields because we only read OldFields)
+}
+
+}  // namespace
+}  // namespace gcpp
+
+HWY_TEST_MAIN();