// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tensor metadata and in-memory representation.

#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>  // fprintf

#include <string>
#include <utility>  // std::forward
#include <vector>

// IWYU pragma: begin_exports
#include "compression/types.h"  // Type
#include "gemma/tensor_info.h"
#include "io/fields.h"
#include "util/allocator.h"  // AlignedPtr
#include "util/basics.h"  // Extents2D
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"  // hwy::AlignedFreeUniquePtr
#include "hwy/base.h"

namespace gcpp {

// Type-safe wrapper over type-erased uint8_t row pointers from MatPtr. Used
// for C (KV output), in future also for A or even B.
template <typename T>
class RowPtrs {
 public:
  RowPtrs(uint8_t** row_ptrs) : row_ptrs_(row_ptrs), r0_(0), c0_(0) {}

  // Extra argument is for compatibility with `StridedView`.
  RowPtrs View(size_t r, size_t c, size_t /*cols*/) {
    RowPtrs<T> view(row_ptrs_);
    view.r0_ = static_cast<uint32_t>(r0_ + r);
    view.c0_ = static_cast<uint32_t>(c0_ + c);
    return view;
  }

  T* HWY_RESTRICT Row(size_t row_idx) const {
    return HWY_RCAST_ALIGNED(T*, row_ptrs_[r0_ + row_idx]) + c0_;
  }

 private:
  uint8_t** row_ptrs_;
  uint32_t r0_;
  uint32_t c0_;
};

using RowPtrsBF = RowPtrs<BF16>;
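
// Example (illustrative sketch, not part of the API; `c` is a hypothetical
// MatMul output tensor whose row pointers were attached via `AttachRowPtrs`):
//   RowPtrsBF c_rows(c.GetRowPtrs());
//   // Typed view whose top-left element is row 2, column 16 of `c`.
//   RowPtrsBF tile = c_rows.View(/*r=*/2, /*c=*/16, /*cols=*/0);
//   BF16* row = tile.Row(0);  // points at c.Row(2) + 16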

// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store heterogeneous tensor references in a vector.
//
// Copyable, (de)serializable via `fields.h` for `model_store.h`.
class MatPtr : public IFields {
 public:
  MatPtr() = default;
  // `name`: see `SetName`. Note that `stride` is initially `cols` and only
  // differs after deserializing, or calling `SetPtr`.
  MatPtr(const char* name, Type type, Extents2D extents)
      : private_rows_(static_cast<uint32_t>(extents.rows)),
        cols_(static_cast<uint32_t>(extents.cols)) {
    SetName(name);
    SetType(type);
    SetPtr(nullptr, cols_);
  }

  // Copying allowed because the metadata is small.
  MatPtr(const MatPtr& other) = default;
  MatPtr& operator=(const MatPtr& other) = default;

  virtual ~MatPtr() = default;

  // Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors.
  void SetPtr(void* ptr, size_t stride) {
    HWY_ASSERT(stride >= Cols());
    ptr_ = ptr;
    stride_ = static_cast<uint32_t>(stride);

    // If row pointers were already attached, `SetPtr` would invalidate them.
    HWY_DASSERT_M(row_ptrs_ == nullptr, "Do not call after AttachRowPtrs.");

    // NUQ streams must not be padded because that would change the position of
    // the group tables.
    if (type_ == Type::kNUQ) {
      HWY_ASSERT_M(GEMMA_ENABLE_NUQ, "Set GEMMA_ENABLE_NUQ=1.");
      HWY_ASSERT(IsPacked());
    }
  }

  bool HasPtr() const { return ptr_ != nullptr; }

  // Caller has initialized Rows() pointers in row_ptrs[]. Note that this only
  // changes `GetRowPtrs`, not `Row()`, because that would require branching
  // and only a few call sites, in particular MatMul, use row pointers.
  void AttachRowPtrs(uint8_t** row_ptrs) {
    row_ptrs_ = row_ptrs;
    for (size_t r = 0; r < Rows(); ++r) {
      HWY_DASSERT(row_ptrs[r] != nullptr);
    }
  }

  // Called by Activations to allocate once, rather than have to fill row
  // pointers in each call to MatMul.
  void AllocateAndAttachRowPtrs(
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs) {
    if (!HasPtr()) return;
    row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(Rows()));
    uint8_t** ptrs = row_ptrs.back().get();
    for (size_t r = 0; r < Rows(); ++r) {
      ptrs[r] = RowBytes(r);
    }
    AttachRowPtrs(ptrs);
  }
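
  // Usage sketch (illustrative; `att_out` and `activations_row_ptrs` are
  // hypothetical caller-owned objects that keep the pointer arrays alive):
  //   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> activations_row_ptrs;
  //   att_out.AllocateAndAttachRowPtrs(activations_row_ptrs);
  //   // Later, MatMul can use att_out.GetRowPtrs() without per-call setup.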

  // If non-null, this array should be used instead of `Row()`.
  uint8_t** GetRowPtrs() const { return row_ptrs_; }

  // A single row counts as packed because there is no padding between rows.
  bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }

  const void* Packed() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return ptr_;
  }
  void* Packed() {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return ptr_;
  }

  // Returns size in bytes for purposes of copying/initializing or I/O. Must
  // only be called if `IsPacked`.
  size_t PackedBytes() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    // num_elements_ already includes the NUQ tables.
    return num_elements_ * element_bytes_;
  }

  // Works for any kind of padding and element type.
  uint8_t* RowBytes(size_t row) {
    HWY_DASSERT(row < Rows());
    return static_cast<uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
  }
  const uint8_t* RowBytes(size_t row) const {
    HWY_DASSERT(row < Rows());
    return static_cast<const uint8_t*>(ptr_) + row * (stride_ * element_bytes_);
  }

  Type GetType() const { return type_; }
  void SetType(Type type) {
    type_ = type;
    if (type == Type::kUnknown) {
      // Temporary invalid state. Happens during weights.h construction, before
      // the ForEachTensor that loads them and sets the type.
      element_bytes_ = 0;
      num_elements_ = 0;
      return;
    }
    element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
    num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
    HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
  }

  size_t Rows() const {
    return override_rows_ == 0 ? private_rows_ : override_rows_;
  }
  size_t Cols() const { return cols_; }
  Extents2D Extents() const { return Extents2D(Rows(), cols_); }
  bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
  bool SameShape(const MatPtr& other) const {
    return Rows() == other.Rows() && Cols() == other.Cols();
  }
  void DebugCheckSameShape(const MatPtr& other) const {
    if constexpr (HWY_IS_DEBUG_BUILD) {
      if (!SameShape(other)) {
        HWY_ABORT("%s: shape mismatch %zu x %zu vs %zu x %zu\n", name_.c_str(),
                  Rows(), Cols(), other.Rows(), other.Cols());
      }
    }
  }
  // Future calls to `Rows()` during this class' lifetime (not serialized)
  // will return this value. Used to set the actual number of rows for
  // activations preallocated according to the batch size.
  void OverrideRows(size_t rows) {
    if (HWY_UNLIKELY(rows > private_rows_)) {
      HWY_ABORT("%s: rows %zu > private_rows_ %u\n", name_.c_str(), rows,
                private_rows_);
    }
    override_rows_ = static_cast<uint32_t>(rows);
  }
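
  // Usage sketch (illustrative; `x`, `kMaxBatch`, `kCols`, `allocator` and
  // `actual_batch` are hypothetical; the tensor is preallocated for the
  // maximum batch size and then shrunk per call):
  //   MatStorageT<float> x("x", Extents2D(kMaxBatch, kCols), allocator,
  //                        MatPadding::kOdd);
  //   x.OverrideRows(actual_batch);  // Rows() now returns actual_batch.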

  // Offset by which to advance pointers to the next row.
  size_t Stride() const { return stride_; }

  // For use by `BlobStore`, `CopyMat` and `ZeroInit`.
  size_t ElementBytes() const { return element_bytes_; }

  // Decoded elements should be multiplied by this to restore their original
  // range. This is required because `SfpStream` can only encode a limited range
  // of magnitudes.
  float Scale() const { return scale_; }
  void SetScale(float scale) { scale_ = scale; }

  // A terse identifier unique across all tensors of the model.
  const char* Name() const override { return name_.c_str(); }
  // `MakeKey` in `blob_store.cc` requires that this be <= 16 bytes, including
  // the `LayerSuffix` for per-layer tensors.
  void SetName(const char* name) {
    name_ = name;
    HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
  }

  void VisitFields(IFieldsVisitor& visitor) override {
    // Order determines the order of serialization and must not change.
    visitor(name_);
    visitor(type_);
    visitor(element_bytes_);
    visitor(num_elements_);
    visitor(private_rows_);
    visitor(cols_);
    visitor(scale_);
    visitor(stride_);
  }

 protected:
  // For initializing `num_elements_`: "elements" are how many objects we
  // actually store in order to represent rows * cols values. For NUQ, this is
  // greater because it includes additional per-group tables. This is the only
  // place where we compute this fixup. Note that elements are independent of
  // padding, which is anyway not supported for NUQ because `compress-inl.h`
  // assumes a contiguous stream for its group indexing.
  static size_t ComputeNumElements(Type type, Extents2D extents) {
    size_t num_elements = extents.Area();
    if (type == Type::kNUQ) {
      // `CompressedArrayElements` is a wrapper function that has the same
      // effect, but that requires a template argument, not `type`.
      num_elements = NuqStream::PackedEnd(num_elements);
    } else if (type == Type::kI8) {
      num_elements = I8Stream::PackedEnd(num_elements);
    }
    return num_elements;
  }

  std::string name_;  // See `SetName`.
  Type type_;

  // Most members are u32 because that is the preferred type of fields.h.

  // Bytes per element. This is fully determined by `type_`, but stored here
  // for convenience and backward compatibility.
  uint32_t element_bytes_ = 0;
  // Number of elements to store (including NUQ tables but not padding).
  // This is a function of `type_` and `Extents()`, stored for compatibility.
  uint32_t num_elements_ = 0;
  uint32_t private_rows_ = 0;  // Only access via Rows()! See OverrideRows().
  uint32_t cols_ = 0;

  uint32_t override_rows_ = 0;  // not serialized

  // Non-owning pointer, must not be freed. The underlying memory must outlive
  // this object.
  void* ptr_ = nullptr;  // not serialized

  // Points to an array of pointers, one per row, or nullptr if `AttachRowPtrs`
  // was not called. Only used for MatMul output tensors, hence we
  // minimize the cost for other tensors by only holding a non-owning pointer.
  uint8_t** row_ptrs_ = nullptr;  // not serialized

  // Offset by which to advance pointers to the next row, >= `cols_`.
  uint32_t stride_;

  float scale_ = 1.0f;  // multiplier for each value, for MatMul.
};

// Non-type erased version of `MatPtr`: provides type-safe `Row()` and ensures
// the template argument and `Type` are consistent.
template <typename MatT>
class MatPtrT : public MatPtr {
 public:
  using T = MatT;

  // Default constructor for use with uninitialized views.
  MatPtrT() = default;

  // Called by `MatStorageT`.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}

  // Copying allowed because the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {
    // Happens in weights.h when constructing via MatFinder, which does not
    // know the type. Setting the type here avoids having to keep the
    // initializer list and member type in sync.
    if (GetType() == Type::kUnknown) {
      SetType(TypeEnum<MatT>());
    } else {
      if (HWY_UNLIKELY(other.GetType() != TypeEnum<MatT>())) {
        HWY_ABORT("Type mismatch: MatT %s, constructing from %s",
                  TypeName<MatT>(), TypeName(other.GetType()));
      }
    }
  }
  MatPtrT& operator=(const MatPtr& other) {
    MatPtr::operator=(other);
    return *this;
  }
  MatPtrT(const MatPtrT& other) = default;
  MatPtrT& operator=(const MatPtrT& other) = default;

  // Returns the entire tensor after checking the scale is 1.0 because callers
  // will ignore it. Used for `MatMul` bias vectors and norm weights.
  const MatT* PackedScale1() const {
    HWY_DASSERT(Scale() == 1.0f);
    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
  }

  MatT* Row(size_t row) { return HWY_RCAST_ALIGNED(T*, RowBytes(row)); }
  const MatT* Row(size_t row) const {
    return HWY_RCAST_ALIGNED(const T*, RowBytes(row));
  }

  hwy::Span<MatT> RowSpan(size_t row) {
    return hwy::Span<MatT>(Row(row), Cols());
  }
  hwy::Span<const MatT> RowSpan(size_t row) const {
    return hwy::Span<const MatT>(Row(row), Cols());
  }

  PackedSpan<const MatT> PaddedSpan() const {
    const size_t num = IsPacked() ? num_elements_ : Rows() * Stride();
    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num);
  }

  // For `compress-inl.h` functions, which assume contiguous streams and thus
  // require packed layout.
  PackedSpan<MatT> Span() {
    HWY_ASSERT(IsPacked());
    return MakeSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
  }
  PackedSpan<const MatT> Span() const {
    HWY_ASSERT(IsPacked());
    return MakeConstSpan(HWY_RCAST_ALIGNED(MatT*, ptr_), num_elements_);
  }
};
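
// Example (illustrative sketch; `weights` is a hypothetical MatPtrT<float>
// that already has storage, e.g. allocated by MatOwner::AllocateFor below):
//   float sum = 0.0f;
//   for (size_t r = 0; r < weights.Rows(); ++r) {
//     for (float v : weights.RowSpan(r)) sum += v;
//   }
//   // Remember the per-tensor scale: decoded values are value * Scale().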

// Returns typed row pointers for `mat`: prefers those attached via
// `AttachRowPtrs`; otherwise fills `storage`, which must hold at least
// `mat.Rows()` pointers.
template <typename T>
RowPtrs<T> GetOrSetTempRowPtrs(
    const MatPtrT<T>& mat,
    const hwy::AlignedFreeUniquePtr<uint8_t*[]>& storage) {
  if (HWY_LIKELY(mat.GetRowPtrs())) return RowPtrs<T>(mat.GetRowPtrs());

  if constexpr (HWY_IS_DEBUG_BUILD) {
    fprintf(stderr,
            "MatMul perf warning: setting row pointers because "
            "%s.AttachRowPtrs() was not called.\n",
            mat.Name());
  }
  HWY_DASSERT(mat.HasPtr());
  for (size_t r = 0; r < mat.Rows(); ++r) {
    storage[r] = reinterpret_cast<uint8_t*>(const_cast<T*>(mat.Row(r)));
  }
  return RowPtrs<T>(storage.get());
}
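
// Usage sketch (illustrative; in MatMul-like code, `C` is a hypothetical
// MatPtrT<float> output and `temp` is scratch with at least C.Rows() entries):
//   hwy::AlignedFreeUniquePtr<uint8_t*[]> temp =
//       hwy::AllocateAligned<uint8_t*>(C.Rows());
//   RowPtrs<float> c_rows = GetOrSetTempRowPtrs(C, temp);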

// Calls `func` with `MatPtrT<T>*` plus the optional `args`. This supports all
// types used as weights.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
                            Args&&... args) {
  if constexpr (GEMMA_ENABLE_NUQ) {
    if (base->GetType() == Type::kNUQ) {
      const MatPtrT<NuqStream> mat(*base);
      return func(&mat, std::forward<Args>(args)...);
    }
  }

  if (base->GetType() == Type::kF32) {
    const MatPtrT<float> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kBF16) {
    const MatPtrT<BF16> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kSFP) {
    const MatPtrT<SfpStream> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kI8) {
    const MatPtrT<I8Stream> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }
}
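
// Example (illustrative sketch; `tensor` is a hypothetical `const MatPtr*`).
// The generic lambda is instantiated once per supported weight type:
//   const size_t packed_bytes = CallUpcasted(
//       tensor, [](const auto* mat) { return mat->PackedBytes(); });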

// As `CallUpcasted`, but calls `func` with two `MatPtrT<T>*` plus the optional
// `args`. Both `base1` and `base2` must have the same type.
template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
                                const Func& func, Args&&... args) {
  HWY_DASSERT(base1->GetType() == base2->GetType());

  if constexpr (GEMMA_ENABLE_NUQ) {
    if (base1->GetType() == Type::kNUQ) {
      const MatPtrT<NuqStream> mat1(*base1);
      const MatPtrT<NuqStream> mat2(*base2);
      return func(&mat1, &mat2, std::forward<Args>(args)...);
    }
  }

  if (base1->GetType() == Type::kF32) {
    const MatPtrT<float> mat1(*base1);
    const MatPtrT<float> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kBF16) {
    const MatPtrT<BF16> mat1(*base1);
    const MatPtrT<BF16> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kSFP) {
    const MatPtrT<SfpStream> mat1(*base1);
    const MatPtrT<SfpStream> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else if (base1->GetType() == Type::kI8) {
    const MatPtrT<I8Stream> mat1(*base1);
    const MatPtrT<I8Stream> mat2(*base2);
    return func(&mat1, &mat2, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
  }
}

// Like CallUpcasted, but only for activation types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
                                      Args&&... args) {
  if (base->GetType() == Type::kF32) {
    const MatPtrT<float> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kBF16) {
    const MatPtrT<BF16> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }
}

// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
                              Args&&... args) {
  if (base->GetType() == Type::kF32) {
    const MatPtrT<float> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else if (base->GetType() == Type::kBF16) {
    const MatPtrT<BF16> mat(*base);
    return func(&mat, std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
  }
}

void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);

// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
enum class MatPadding {
  // None, stride == cols. `compress-inl.h` requires this layout because its
  // interface assumes a contiguous 1D array, without awareness of rows. Note
  // that tensors which were written via `compress-inl.h` (i.e. most in
  // `BlobStore`) are not padded, which also extends to memory-mapped tensors.
  // However, `BlobStore` is able to insert padding via row-wise I/O when
  // reading from disk via `Mode::kRead`.
  kPacked,
  // Enough to round up to an odd number of cache lines, which can reduce
  // cache conflict misses or 4K aliasing.
  kOdd,
};
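
// Layout illustration (a sketch, assuming a hypothetical 3x5 f32 tensor and
// 64-byte cache lines; the exact rounding is computed by `Stride()` below):
//   kPacked: stride == 5, so row r starts at byte offset r * 5 * 4 = r * 20
//            and PackedBytes() == 3 * 5 * 4 = 60 bytes with no gaps.
//   kOdd:    each row is rounded up to an odd number of cache lines; here one
//            line = 64 bytes = 16 elements, i.e. 11 padding elements per row.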

// The stride (offset in elements between rows) that `MatOwner/MatStorageT`
// will use.
size_t Stride(MatPadding padding, size_t cols, size_t element_bytes,
              size_t line_bytes);

// Type-erased, allows storing `AlignedPtr<T[]>` for various T in the same
// vector.
class MatOwner {
 public:
  MatOwner() = default;
  // Allow move for `MatStorageT`.
  MatOwner(MatOwner&&) = default;
  MatOwner& operator=(MatOwner&&) = default;

  // Allocates the type/extents indicated by `mat` and sets its pointer.
  // Ignores `padding` for NUQ tensors, which are always packed.
  // Thread-compatible, weights are allocated in parallel.
  void AllocateFor(MatPtr& mat, const Allocator& allocator, MatPadding padding);

 private:
  AlignedPtr<uint8_t[]> storage_;
};
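
// Usage sketch (illustrative; `owners`, `allocator`, `rows` and `cols` are
// hypothetical; the MatOwner must outlive any use of the tensor's pointer):
//   std::vector<MatOwner> owners;
//   MatPtrT<BF16> tensor("w_q", Extents2D(rows, cols));
//   owners.emplace_back();
//   owners.back().AllocateFor(tensor, allocator, MatPadding::kOdd);
//   // tensor.HasPtr() is now true; the memory is owned by owners.back().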

// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by tests to allocate
// and access tensors of a known type. By contrast, the heterogeneous model
// weights are owned by vectors of `MatOwner`.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
 public:
  MatStorageT() = default;  // for std::vector in Activations.
  MatStorageT(const char* name, Extents2D extents, const Allocator& allocator,
              MatPadding padding)
      : MatPtrT<MatT>(name, extents) {
    if (extents.Area() != 0) owner_.AllocateFor(*this, allocator, padding);
  }
  // Shorthand for 1D tensors: packing does not help, hence `kPacked`.
  MatStorageT(const char* name, size_t cols, const Allocator& allocator)
      : MatStorageT(name, Extents2D(1, cols), allocator, MatPadding::kPacked) {}
  ~MatStorageT() = default;

  // Allow move for KVCache.
  MatStorageT(MatStorageT&&) = default;
  MatStorageT& operator=(MatStorageT&&) = default;

 private:
  MatOwner owner_;
};
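
// Example (illustrative sketch, e.g. in a test; `allocator` is a hypothetical
// gcpp::Allocator instance):
//   MatStorageT<float> c("c", Extents2D(4, 256), allocator, MatPadding::kOdd);
//   ZeroInit(c);
//   c.Row(0)[0] = 1.0f;  // typed access; the storage is freed with `c`.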

// Helper for initializing members which are `MatStorageT<T>`: avoids having to
// specify Extents2D and MatPadding at each call site.
class MatFactory {
 public:
  // The constructor captures all the necessary arguments.
  MatFactory(const char* name, size_t rows, size_t cols,
             const Allocator& allocator, MatPadding padding = MatPadding::kOdd)
      : name_(name),
        extents_(rows, cols),
        allocator_(allocator),
        padding_(padding) {}

  // Templated conversion so we do not have to specify the type in the
  // member initializer.
  template <typename T>
  operator MatStorageT<T>() const {
    return MatStorageT<T>(name_.c_str(), extents_, allocator_, padding_);
  }

 private:
  const std::string name_;
  Extents2D extents_;
  const Allocator& allocator_;
  MatPadding padding_;
};
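
// Example (illustrative sketch; `Activations`, `kBatch` and `kModelDim` are
// hypothetical). The conversion operator supplies the MatStorageT<float>
// without repeating the element type in the factory call:
//   struct Activations {
//     explicit Activations(const Allocator& allocator)
//         : q_(MatFactory("q", kBatch, kModelDim, allocator)) {}
//     MatStorageT<float> q_;
//   };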

// Lightweight view into `MatStorageT`, with a fixed pitch/stride between rows.
// Also used to decompress B, hence non-const.
#pragma pack(push, 1)  // power of two size
template <typename T>
class StridedView {
 public:
  StridedView(T* HWY_RESTRICT row0, size_t cols, size_t stride)
      : row0_(row0),
        cols_(static_cast<uint32_t>(cols)),
        stride_(static_cast<uint32_t>(stride)) {
    if constexpr (HWY_IS_DEBUG_BUILD) {
      if (stride < cols) {
        HWY_ABORT("stride %zu < cols %zu", stride, cols);
      }
    }
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  StridedView(const MatPtrT<T>& mat, size_t r, size_t c, size_t cols)
      : StridedView(const_cast<T*>(mat.Row(r)) + c, cols, mat.Stride()) {
    HWY_DASSERT(c < mat.Cols());
    HWY_DASSERT(cols <= mat.Cols() - c);
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  StridedView<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
    return StridedView<T>(Row(r) + c, cols, stride_);
  }

  T* HWY_RESTRICT Row(size_t r) const { return row0_ + stride_ * r; }
  size_t Cols() const { return static_cast<size_t>(cols_); }

  size_t Stride() const { return static_cast<size_t>(stride_); }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = static_cast<uint32_t>(stride);
  }

 private:
  T* HWY_RESTRICT row0_;
  uint32_t cols_;
  uint32_t stride_;
};
#pragma pack(pop)
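
// Example (illustrative sketch; `b` is a hypothetical MatPtrT<BF16> with
// storage):
//   // 64-column window whose top-left is row 8, column 128 of `b`.
//   StridedViewBF tile(b, /*r=*/8, /*c=*/128, /*cols=*/64);
//   BF16* row0 = tile.Row(0);                 // == b.Row(8) + 128
//   StridedViewBF sub = tile.View(2, 4, 32);  // coordinates relative to tile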

using StridedViewBF = StridedView<BF16>;
using StridedViewD = StridedView<double>;

}  // namespace gcpp
#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_