Minor cleanup: enable 0,0 Extents2D, add SerializedSpan typedef, include fixes

PiperOrigin-RevId: 745068776
This commit is contained in:
Jan Wassenberg 2025-04-08 03:35:08 -07:00 committed by Copybara-Service
parent 76a81ac2d6
commit 4e6aa36e9b
6 changed files with 16 additions and 20 deletions

View File

@ -265,10 +265,10 @@ cc_library(
"gemma/tensor_index.h", "gemma/tensor_index.h",
], ],
deps = [ deps = [
":basics",
"//compression:fields", "//compression:fields",
"//compression:sfp", "//compression:sfp",
"@highway//:hwy", # base.h "@highway//:hwy", # base.h
"@highway//:thread_pool",
], ],
) )

View File

@ -24,7 +24,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
namespace gcpp { namespace gcpp {
@ -115,7 +114,7 @@ class PrintVisitor : public VisitorBase {
class ReadVisitor : public VisitorBase { class ReadVisitor : public VisitorBase {
public: public:
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos) ReadVisitor(const SerializedSpan span, size_t pos)
: span_(span), result_(pos) {} : span_(span), result_(pos) {}
~ReadVisitor() { ~ReadVisitor() {
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced. HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
@ -236,7 +235,7 @@ class ReadVisitor : public VisitorBase {
} }
private: private:
const hwy::Span<const uint32_t> span_; const SerializedSpan span_;
IFields::ReadResult result_; IFields::ReadResult result_;
// Stack of end positions of nested IFields. Updated in operator()(IFields&), // Stack of end positions of nested IFields. Updated in operator()(IFields&),
// but read in SkipField. // but read in SkipField.
@ -326,8 +325,7 @@ void IFields::Print() const {
visitor(*const_cast<IFields*>(this)); visitor(*const_cast<IFields*>(this));
} }
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span, IFields::ReadResult IFields::Read(const SerializedSpan span, size_t pos) {
size_t pos) {
ReadVisitor visitor(span, pos); ReadVisitor visitor(span, pos);
visitor(*this); visitor(*this);
return visitor.Result(); return visitor.Result();

View File

@ -27,7 +27,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" #include "hwy/base.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
@ -133,6 +133,8 @@ class IFieldsVisitor {
bool any_invalid_ = false; bool any_invalid_ = false;
}; };
using SerializedSpan = hwy::Span<const uint32_t>;
// Abstract base class for user-defined serializable classes, which are // Abstract base class for user-defined serializable classes, which are
// forward- and backward compatible collection of fields (members). This means // forward- and backward compatible collection of fields (members). This means
// old code can safely read new data, and new code can still handle old data. // old code can safely read new data, and new code can still handle old data.
@ -178,13 +180,13 @@ struct IFields {
// the code, but valid, and extra_u32 should be zero. // the code, but valid, and extra_u32 should be zero.
uint32_t missing_fields; uint32_t missing_fields;
// How many extra u32 are in the stored size, vs. what we actually read as // How many extra u32 are in the stored size, vs. what we actually read as
// requested by VisitFields. If non-zero,, the data is newer than the code, // requested by VisitFields. If non-zero, the data is newer than the code,
// but valid, and missing_fields should be zero. // but valid, and missing_fields should be zero.
uint32_t extra_u32; uint32_t extra_u32;
}; };
// Reads fields starting at `span[pos]`. // Reads fields starting at `span[pos]`.
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos); ReadResult Read(SerializedSpan span, size_t pos);
// Returns false if there was an unrecoverable error, typically because a // Returns false if there was an unrecoverable error, typically because a
// field has an invalid value. If so, `storage` is undefined. // field has an invalid value. If so, `storage` is undefined.

View File

@ -24,9 +24,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "compression/shared.h" #include "util/basics.h" // BF16
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
@ -162,8 +161,8 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
float EmbeddingScaling(size_t model_dim) { float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul. // Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>( return hwy::ConvertScalarTo<float>(
sqrtf(static_cast<float>(model_dim)))); hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
} }
float ChooseQueryScale(const ModelConfig& config) { float ChooseQueryScale(const ModelConfig& config) {

View File

@ -20,9 +20,9 @@
#include <string> #include <string>
#include "compression/shared.h" // PromptWrapping #include "compression/shared.h" // Type
#include "gemma/configs.h" // IWYU pragma: export #include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo #include "hwy/base.h" // ConvertScalarTo
namespace gcpp { namespace gcpp {

View File

@ -67,10 +67,7 @@ struct TokenAndProb {
// Entire size of a 2D array. // Entire size of a 2D array.
struct Extents2D { struct Extents2D {
constexpr Extents2D() : rows(0), cols(0) {} constexpr Extents2D() : rows(0), cols(0) {}
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) { constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {}
HWY_DASSERT(rows != 0);
HWY_DASSERT(cols != 0);
}
size_t Area() const { return rows * cols; } size_t Area() const { return rows * cols; }