mirror of https://github.com/google/gemma.cpp.git
Minor cleanup: enable 0,0 Extents2D, add SerializedSpan typedef, include fixes
PiperOrigin-RevId: 745068776
This commit is contained in:
parent
76a81ac2d6
commit
4e6aa36e9b
|
|
@ -265,10 +265,10 @@ cc_library(
|
|||
"gemma/tensor_index.h",
|
||||
],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy", # base.h
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -115,7 +114,7 @@ class PrintVisitor : public VisitorBase {
|
|||
|
||||
class ReadVisitor : public VisitorBase {
|
||||
public:
|
||||
ReadVisitor(const hwy::Span<const uint32_t>& span, size_t pos)
|
||||
ReadVisitor(const SerializedSpan span, size_t pos)
|
||||
: span_(span), result_(pos) {}
|
||||
~ReadVisitor() {
|
||||
HWY_ASSERT(end_.empty()); // Bug if push/pop are not balanced.
|
||||
|
|
@ -236,7 +235,7 @@ class ReadVisitor : public VisitorBase {
|
|||
}
|
||||
|
||||
private:
|
||||
const hwy::Span<const uint32_t> span_;
|
||||
const SerializedSpan span_;
|
||||
IFields::ReadResult result_;
|
||||
// Stack of end positions of nested IFields. Updated in operator()(IFields&),
|
||||
// but read in SkipField.
|
||||
|
|
@ -326,8 +325,7 @@ void IFields::Print() const {
|
|||
visitor(*const_cast<IFields*>(this));
|
||||
}
|
||||
|
||||
IFields::ReadResult IFields::Read(const hwy::Span<const uint32_t>& span,
|
||||
size_t pos) {
|
||||
IFields::ReadResult IFields::Read(const SerializedSpan span, size_t pos) {
|
||||
ReadVisitor visitor(span, pos);
|
||||
visitor(*this);
|
||||
return visitor.Result();
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
|
|
@ -133,6 +133,8 @@ class IFieldsVisitor {
|
|||
bool any_invalid_ = false;
|
||||
};
|
||||
|
||||
using SerializedSpan = hwy::Span<const uint32_t>;
|
||||
|
||||
// Abstract base class for user-defined serializable classes, which are
|
||||
// forward- and backward compatible collection of fields (members). This means
|
||||
// old code can safely read new data, and new code can still handle old data.
|
||||
|
|
@ -178,13 +180,13 @@ struct IFields {
|
|||
// the code, but valid, and extra_u32 should be zero.
|
||||
uint32_t missing_fields;
|
||||
// How many extra u32 are in the stored size, vs. what we actually read as
|
||||
// requested by VisitFields. If non-zero,, the data is newer than the code,
|
||||
// requested by VisitFields. If non-zero, the data is newer than the code,
|
||||
// but valid, and missing_fields should be zero.
|
||||
uint32_t extra_u32;
|
||||
};
|
||||
|
||||
// Reads fields starting at `span[pos]`.
|
||||
ReadResult Read(const hwy::Span<const uint32_t>& span, size_t pos);
|
||||
ReadResult Read(SerializedSpan span, size_t pos);
|
||||
|
||||
// Returns false if there was an unrecoverable error, typically because a
|
||||
// field has an invalid value. If so, `storage` is undefined.
|
||||
|
|
|
|||
|
|
@ -24,9 +24,8 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "util/basics.h" // BF16
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -162,8 +161,8 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
|
|||
|
||||
float EmbeddingScaling(size_t model_dim) {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||
sqrtf(static_cast<float>(model_dim))));
|
||||
return hwy::ConvertScalarTo<float>(
|
||||
hwy::ConvertScalarTo<BF16>(sqrtf(static_cast<float>(model_dim))));
|
||||
}
|
||||
|
||||
float ChooseQueryScale(const ModelConfig& config) {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
#include <string>
|
||||
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "compression/shared.h" // Type
|
||||
#include "gemma/configs.h" // IWYU pragma: export
|
||||
#include "hwy/base.h" // ConvertScalarTo
|
||||
|
||||
|
|
|
|||
|
|
@ -67,10 +67,7 @@ struct TokenAndProb {
|
|||
// Entire size of a 2D array.
|
||||
struct Extents2D {
|
||||
constexpr Extents2D() : rows(0), cols(0) {}
|
||||
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {
|
||||
HWY_DASSERT(rows != 0);
|
||||
HWY_DASSERT(cols != 0);
|
||||
}
|
||||
constexpr Extents2D(size_t rows, size_t cols) : rows(rows), cols(cols) {}
|
||||
|
||||
size_t Area() const { return rows * cols; }
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue