Minor cleanup: enable 0,0 Extents2D, add SerializedSpan typedef, include fixes

PiperOrigin-RevId: 745068776
This commit is contained in:
Jan Wassenberg 2025-04-08 03:35:08 -07:00 committed by Copybara-Service
parent 76a81ac2d6
commit 4e6aa36e9b
6 changed files with 16 additions and 20 deletions

View File

@ -265,10 +265,10 @@ cc_library(
"gemma/tensor_index.h",
],
deps = [
":basics",
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)

View File

@ -24,7 +24,6 @@
#include <string>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
namespace gcpp {
@ -115,7 +114,7 @@ class PrintVisitor : public VisitorBase {
class ReadVisitor : public VisitorBase {
public:
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
ReadVisitor(const SerializedSpan span, size_t pos)
: span_(span), result_(pos) {}
~ReadVisitor() {
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
@ -236,7 +235,7 @@ class ReadVisitor : public VisitorBase {
}
private:
const hwy::Span<const uint32_t> span_;
const SerializedSpan span_;
IFields::ReadResult result_;
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
// but read in SkipField.
@ -326,8 +325,7 @@ void IFields::Print() const {
visitor(*const_cast<IFields*>(this));
}
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
size_t pos) {
IFields::ReadResult IFields::Read(const SerializedSpan span, size_t pos) {
ReadVisitor visitor(span, pos);
visitor(*this);
return visitor.Result();

View File

@ -27,7 +27,7 @@
#include <string>
#include <vector>
#include "hwy/aligned_allocator.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
// IWYU pragma: end_exports
@ -133,6 +133,8 @@ class IFieldsVisitor {
bool any_invalid_ = false;
};
using SerializedSpan = hwy::Span<const uint32_t>;
// Abstract base class for user-defined serializable classes, which are
// forward- and backward compatible collection of fields (members). This means
// old code can safely read new data, and new code can still handle old data.
@ -178,13 +180,13 @@ struct IFields {
// the code, but valid, and extra_u32 should be zero.
uint32_t missing_fields;
// How many extra u32 are in the stored size, vs. what we actually read as
// requested by VisitFields. If non-zero,, the data is newer than the code,
// requested by VisitFields. If non-zero, the data is newer than the code,
// but valid, and missing_fields should be zero.
uint32_t extra_u32;
};
// Reads fields starting at `span[pos]`.
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);
ReadResult Read(SerializedSpan span, size_t pos);
// Returns false if there was an unrecoverable error, typically because a
// field has an invalid value. If so, `storage` is undefined.

View File

@ -24,9 +24,8 @@
#include <string>
#include <vector>
#include "compression/shared.h"
#include "util/basics.h" // BF16
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -162,8 +161,8 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
sqrtf(static_cast<float>(model_dim))));
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
}
float ChooseQueryScale(const ModelConfig& config) {

View File

@ -20,7 +20,7 @@
#include <string>
#include "compression/shared.h" // PromptWrapping
#include "compression/shared.h" // Type
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo

View File

@ -67,10 +67,7 @@ struct TokenAndProb {
// Entire size of a 2D array.
struct Extents2D {
constexpr Extents2D() : rows(0), cols(0) {}
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
HWY_DASSERT(rows != 0);
HWY_DASSERT(cols != 0);
}
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {}
size_t Area() const { return rows * cols; }