mirror of https://github.com/google/gemma.cpp.git
Minor cleanup: enable 0,0 Extents2D, add SerializedSpan typedef, include fixes
PiperOrigin-RevId: 745068776
This commit is contained in:
parent
76a81ac2d6
commit
4e6aa36e9b
|
|
@ -265,10 +265,10 @@ cc_library(
|
||||||
"gemma/tensor_index.h",
|
"gemma/tensor_index.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":basics",
|
||||||
"//compression:fields",
|
"//compression:fields",
|
||||||
"//compression:sfp",
|
"//compression:sfp",
|
||||||
"@highway//:hwy", # base.h
|
"@highway//:hwy", # base.h
|
||||||
"@highway//:thread_pool",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -115,7 +114,7 @@ class PrintVisitor : public VisitorBase {
|
||||||
|
|
||||||
class ReadVisitor : public VisitorBase {
|
class ReadVisitor : public VisitorBase {
|
||||||
public:
|
public:
|
||||||
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
|
ReadVisitor(const SerializedSpan span, size_t pos)
|
||||||
: span_(span), result_(pos) {}
|
: span_(span), result_(pos) {}
|
||||||
~ReadVisitor() {
|
~ReadVisitor() {
|
||||||
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
|
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
|
||||||
|
|
@ -236,7 +235,7 @@ class ReadVisitor : public VisitorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const hwy::Span<const uint32_t> span_;
|
const SerializedSpan span_;
|
||||||
IFields::ReadResult result_;
|
IFields::ReadResult result_;
|
||||||
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
|
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
|
||||||
// but read in SkipField.
|
// but read in SkipField.
|
||||||
|
|
@ -326,8 +325,7 @@ void IFields::Print() const {
|
||||||
visitor(*const_cast<IFields*>(this));
|
visitor(*const_cast<IFields*>(this));
|
||||||
}
|
}
|
||||||
|
|
||||||
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
|
IFields::ReadResult IFields::Read(const SerializedSpan span, size_t pos) {
|
||||||
size_t pos) {
|
|
||||||
ReadVisitor visitor(span, pos);
|
ReadVisitor visitor(span, pos);
|
||||||
visitor(*this);
|
visitor(*this);
|
||||||
return visitor.Result();
|
return visitor.Result();
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h" // Span
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
|
|
@ -133,6 +133,8 @@ class IFieldsVisitor {
|
||||||
bool any_invalid_ = false;
|
bool any_invalid_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using SerializedSpan = hwy::Span<const uint32_t>;
|
||||||
|
|
||||||
// Abstract base class for user-defined serializable classes, which are
|
// Abstract base class for user-defined serializable classes, which are
|
||||||
// forward- and backward compatible collection of fields (members). This means
|
// forward- and backward compatible collection of fields (members). This means
|
||||||
// old code can safely read new data, and new code can still handle old data.
|
// old code can safely read new data, and new code can still handle old data.
|
||||||
|
|
@ -178,13 +180,13 @@ struct IFields {
|
||||||
// the code, but valid, and extra_u32 should be zero.
|
// the code, but valid, and extra_u32 should be zero.
|
||||||
uint32_t missing_fields;
|
uint32_t missing_fields;
|
||||||
// How many extra u32 are in the stored size, vs. what we actually read as
|
// How many extra u32 are in the stored size, vs. what we actually read as
|
||||||
// requested by VisitFields. If non-zero,, the data is newer than the code,
|
// requested by VisitFields. If non-zero, the data is newer than the code,
|
||||||
// but valid, and missing_fields should be zero.
|
// but valid, and missing_fields should be zero.
|
||||||
uint32_t extra_u32;
|
uint32_t extra_u32;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Reads fields starting at `span[pos]`.
|
// Reads fields starting at `span[pos]`.
|
||||||
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);
|
ReadResult Read(SerializedSpan span, size_t pos);
|
||||||
|
|
||||||
// Returns false if there was an unrecoverable error, typically because a
|
// Returns false if there was an unrecoverable error, typically because a
|
||||||
// field has an invalid value. If so, `storage` is undefined.
|
// field has an invalid value. If so, `storage` is undefined.
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/shared.h"
|
#include "util/basics.h" // BF16
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -162,8 +161,8 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
||||||
|
|
||||||
float EmbeddingScaling(size_t model_dim) {
|
float EmbeddingScaling(size_t model_dim) {
|
||||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
return hwy::ConvertScalarTo<float>(
|
||||||
sqrtf(static_cast<float>(model_dim))));
|
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
|
||||||
}
|
}
|
||||||
|
|
||||||
float ChooseQueryScale(const ModelConfig& config) {
|
float ChooseQueryScale(const ModelConfig& config) {
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "compression/shared.h" // PromptWrapping
|
#include "compression/shared.h" // Type
|
||||||
#include "gemma/configs.h" // IWYU pragma: export
|
#include "gemma/configs.h" // IWYU pragma: export
|
||||||
#include "hwy/base.h" // ConvertScalarTo
|
#include "hwy/base.h" // ConvertScalarTo
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,10 +67,7 @@ struct TokenAndProb {
|
||||||
// Entire size of a 2D array.
|
// Entire size of a 2D array.
|
||||||
struct Extents2D {
|
struct Extents2D {
|
||||||
constexpr Extents2D() : rows(0), cols(0) {}
|
constexpr Extents2D() : rows(0), cols(0) {}
|
||||||
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
|
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {}
|
||||||
HWY_DASSERT(rows != 0);
|
|
||||||
HWY_DASSERT(cols != 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t Area() const { return rows * cols; }
|
size_t Area() const { return rows * cols; }
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue