diff --git a/BUILD.bazel b/BUILD.bazel index 8c32631..f4aed0d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -265,10 +265,10 @@ cc_library( "gemma/tensor_index.h", ], deps = [ + ":basics", "//compression:fields", "//compression:sfp", "@highway//:hwy", # base.h - "@highway//:thread_pool", ], ) diff --git a/compression/fields.cc b/compression/fields.cc index 8977af7..092597f 100644 --- a/compression/fields.cc +++ b/compression/fields.cc @@ -24,7 +24,6 @@ #include #include -#include "hwy/aligned_allocator.h" #include "hwy/base.h" namespace gcpp { @@ -115,7 +114,7 @@ class PrintVisitor : public VisitorBase { class ReadVisitor : public VisitorBase { public: - ReadVisitor(const hwy::Span& span, size_t pos) + ReadVisitor(const SerializedSpan span, size_t pos) : span_(span), result_(pos) {} ~ReadVisitor() { HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced. @@ -236,7 +235,7 @@ class ReadVisitor : public VisitorBase { } private: - const hwy::Span span_; + const SerializedSpan span_; IFields::ReadResult result_; // Stack of end positions of nested IFields. Updated in operator()(IFields&), // but read in SkipField. @@ -326,8 +325,7 @@ void IFields::Print() const { visitor(*const_cast(this)); } -IFields::ReadResult IFields::Read(const hwy::Span& span, - size_t pos) { +IFields::ReadResult IFields::Read(const SerializedSpan span, size_t pos) { ReadVisitor visitor(span, pos); visitor(*this); return visitor.Result(); diff --git a/compression/fields.h b/compression/fields.h index a17b48c..57465c4 100644 --- a/compression/fields.h +++ b/compression/fields.h @@ -27,7 +27,7 @@ #include #include -#include "hwy/aligned_allocator.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // IWYU pragma: end_exports @@ -133,6 +133,8 @@ class IFieldsVisitor { bool any_invalid_ = false; }; +using SerializedSpan = hwy::Span; + // Abstract base class for user-defined serializable classes, which are // forward- and backward compatible collection of fields (members). This means // old code can safely read new data, and new code can still handle old data. @@ -178,13 +180,13 @@ struct IFields { // the code, but valid, and extra_u32 should be zero. uint32_t missing_fields; // How many extra u32 are in the stored size, vs. what we actually read as - // requested by VisitFields. If non-zero,, the data is newer than the code, + // requested by VisitFields. If non-zero, the data is newer than the code, // but valid, and missing_fields should be zero. uint32_t extra_u32; }; // Reads fields starting at `span[pos]`. - ReadResult Read(const hwy::Span& span, size_t pos); + ReadResult Read(SerializedSpan span, size_t pos); // Returns false if there was an unrecoverable error, typically because a // field has an invalid value. If so, `storage` is undefined. diff --git a/gemma/common.cc b/gemma/common.cc index 0d8977b..da782c5 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -24,9 +24,8 @@ #include #include -#include "compression/shared.h" +#include "util/basics.h" // BF16 #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -162,8 +161,8 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { float EmbeddingScaling(size_t model_dim) { // Round to bf16 to match Gemma's Embedder, which casts before mul. - return hwy::ConvertScalarTo(hwy::ConvertScalarTo( - sqrtf(static_cast(model_dim)))); + return hwy::ConvertScalarTo( + hwy::ConvertScalarTo(sqrtf(static_cast(model_dim)))); } float ChooseQueryScale(const ModelConfig& config) { diff --git a/gemma/common.h b/gemma/common.h index 984b0ba..8aa2112 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -20,9 +20,9 @@ #include -#include "compression/shared.h" // PromptWrapping -#include "gemma/configs.h" // IWYU pragma: export -#include "hwy/base.h" // ConvertScalarTo +#include "compression/shared.h" // Type +#include "gemma/configs.h" // IWYU pragma: export +#include "hwy/base.h" // ConvertScalarTo namespace gcpp { diff --git a/util/basics.h b/util/basics.h index b8f2735..40545fd 100644 --- a/util/basics.h +++ b/util/basics.h @@ -67,10 +67,7 @@ struct TokenAndProb { // Entire size of a 2D array. struct Extents2D { constexpr Extents2D() : rows(0), cols(0) {} - constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { - HWY_DASSERT(rows != 0); - HWY_DASSERT(cols != 0); - } + constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {} size_t Area() const { return rows * cols; }