// Mirror of https://github.com/google/gemma.cpp.git (622 lines, 26 KiB, C++).
// Copyright 2024 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
#include <complex>
|
|
#include <memory>
|
|
#include <mutex> // NOLINT
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "compression/types.h" // IsF32
|
|
#include "gemma/configs.h" // ModelConfig
|
|
#include "gemma/model_store.h" // ModelStore
|
|
#include "gemma/tensor_info.h" // TensorInfoRegistry
|
|
#include "io/blob_store.h" // BlobWriter
|
|
#include "ops/matmul.h" // MatMulEnv
|
|
#include "util/mat.h" // MatPtr
|
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
|
|
namespace gcpp {
|
|
|
|
// Argument passed to the `ForEachTensor` callback.
|
|
struct TensorArgs {
|
|
// `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same
|
|
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
|
|
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
|
|
// `flags` is a combination of zero or more `Flags`.
|
|
TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
|
|
: mat(mat),
|
|
other_mat1(other_mat1),
|
|
other_mat2(other_mat2),
|
|
flags(flags) {}
|
|
|
|
MatPtr& mat;
|
|
MatPtr* other_mat1; // either/both can be nullptr.
|
|
MatPtr* other_mat2;
|
|
|
|
enum Flags {
|
|
// Default: Read the tensor from the file and abort if it is not found.
|
|
kMustRead = 0,
|
|
|
|
// Not an error if the tensor is not present in the file. For example,
|
|
// the _w1/_w2 tensors are not always present.
|
|
kMaybeRead = 1,
|
|
|
|
// Avoid padding tensor rows when reading. Used for some Griffin tensors
|
|
// whose index computations do not use Row() accessors.
|
|
kPacked = 2,
|
|
};
|
|
const int flags;
|
|
};
|
|
|
|
// Builds the `TensorArgs` passed to a `ForEachTensor` callback. A macro reads
// better here than member-pointer syntax.
// Requirements at the expansion site:
// - `other1` and `other2` must be in scope. Each is a (possibly null) pointer
//   to an object that also has member `mat`.
// - `mat` names a member of the current object, e.g. `vit.attn_out_w`.
// - `flag` starts with a bare `TensorArgs::Flags` enumerator. Only that first
//   enumerator receives the `TensorArgs::` prefix, so further ones must be
//   qualified by the caller, e.g. `kMustRead | TensorArgs::kPacked`.
#define TENSOR_ARGS(mat, flag)                                      \
  TensorArgs(mat, (other1 != nullptr) ? &(other1->mat) : nullptr,   \
             (other2 != nullptr) ? &(other2->mat) : nullptr,        \
             TensorArgs::flag)
|
|
|
|
// Per-layer weight metadata and pointers. The tensor data is owned by
|
|
// `WeightsOwner`. Note that this class could be type-erased: member functions
|
|
// do not actually use the `Weight` template argument. See `WeightsPtrs`.
|
|
// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth
|
|
// for all tensor shapes.
|
|
template <class Weight>
|
|
struct LayerWeightsPtrs {
|
|
static inline std::string Concat(const char* base_name,
|
|
const std::string& suffix) {
|
|
return std::string(base_name) + suffix;
|
|
}
|
|
|
|
// Initializes tensor metadata without allocating.
|
|
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
|
|
const TensorInfoRegistry& tensors)
|
|
: suffix_(LayerSuffix(layer_idx)),
|
|
qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
|
|
qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
|
|
qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
|
|
attention_output_biases(Concat("attn_ob", suffix_), tensors),
|
|
griffin(
|
|
{.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors},
|
|
.linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors},
|
|
.linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors},
|
|
.linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors},
|
|
.linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors},
|
|
.linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors},
|
|
.conv_w = {Concat("gr_conv_w", suffix_), tensors},
|
|
.conv_biases = {Concat("gr_conv_b", suffix_), tensors},
|
|
.gate_w = {Concat("gr_gate_w", suffix_), tensors},
|
|
.gate_biases = {Concat("gr_gate_b", suffix_), tensors},
|
|
.a = {Concat("gr_a", suffix_), tensors}}),
|
|
// MultiHeadDotProductAttention.
|
|
vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors},
|
|
.attn_out_b = {Concat("attn_out_b", suffix_), tensors},
|
|
.qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors},
|
|
.qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors},
|
|
.linear_0_w = {Concat("linear_0_w", suffix_), tensors},
|
|
.linear_0_b = {Concat("linear_0_b", suffix_), tensors},
|
|
.linear_1_w = {Concat("linear_1_w", suffix_), tensors},
|
|
.linear_1_b = {Concat("linear_1_b", suffix_), tensors},
|
|
.layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors},
|
|
.layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors},
|
|
.layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors},
|
|
.layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}),
|
|
gating_einsum_w(Concat("gating_ein", suffix_), tensors),
|
|
gating_einsum_w1(Concat("gating1_w", suffix_), tensors),
|
|
gating_einsum_w2(Concat("gating2_w", suffix_), tensors),
|
|
linear_w(Concat("linear_w", suffix_), tensors),
|
|
pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors),
|
|
pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors),
|
|
post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors),
|
|
post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
|
|
ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
|
|
ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
|
|
|
|
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
|
|
att_weights(Concat("att_w", suffix_), tensors),
|
|
|
|
key_norm_scale(Concat("key_norm", suffix_), tensors),
|
|
query_norm_scale(Concat("query_norm", suffix_), tensors),
|
|
|
|
layer_config(config) {
|
|
}
|
|
~LayerWeightsPtrs() = default;
|
|
|
|
const std::string suffix_;
|
|
|
|
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
|
|
// do not yet support smaller compressed types, or require at least bf16. When
|
|
// weights are f32, we also want such tensors to be f32.
|
|
// If weights are complex, this is also complex.
|
|
using WeightF32OrBF16 =
|
|
hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
|
|
hwy::If<hwy::IsSame<Weight, double>(), double,
|
|
hwy::If<IsF32<Weight>(), float, BF16>>>;
|
|
|
|
// Files either have qkv_einsum_w with 2 stacked matrices or separate
|
|
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
|
|
MatPtrT<Weight> qkv_einsum_w;
|
|
MatPtrT<Weight> qkv_einsum_w1;
|
|
MatPtrT<Weight> qkv_einsum_w2;
|
|
MatPtrT<float> attention_output_biases;
|
|
|
|
struct {
|
|
MatPtrT<Weight> linear_x_w;
|
|
MatPtrT<float> linear_x_biases;
|
|
MatPtrT<Weight> linear_y_w;
|
|
MatPtrT<float> linear_y_biases;
|
|
MatPtrT<Weight> linear_out_w;
|
|
MatPtrT<float> linear_out_biases;
|
|
MatPtrT<float> conv_w;
|
|
MatPtrT<float> conv_biases;
|
|
MatPtrT<Weight> gate_w;
|
|
MatPtrT<float> gate_biases;
|
|
MatPtrT<float> a;
|
|
} griffin;
|
|
|
|
struct {
|
|
// MultiHeadDotProductAttention.
|
|
MatPtrT<WeightF32OrBF16> attn_out_w;
|
|
MatPtrT<float> attn_out_b;
|
|
MatPtrT<WeightF32OrBF16> qkv_einsum_w;
|
|
MatPtrT<float> qkv_einsum_b;
|
|
// MlpBlock.
|
|
MatPtrT<WeightF32OrBF16> linear_0_w;
|
|
MatPtrT<float> linear_0_b;
|
|
MatPtrT<WeightF32OrBF16> linear_1_w;
|
|
MatPtrT<float> linear_1_b;
|
|
// LayerNorm.
|
|
MatPtrT<WeightF32OrBF16> layer_norm_0_bias;
|
|
MatPtrT<WeightF32OrBF16> layer_norm_0_scale;
|
|
MatPtrT<WeightF32OrBF16> layer_norm_1_bias;
|
|
MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
|
|
} vit;
|
|
|
|
// Files either have gating_einsum_w with 2 stacked matrices or separate
|
|
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
|
|
MatPtrT<Weight> gating_einsum_w;
|
|
MatPtrT<Weight> gating_einsum_w1;
|
|
MatPtrT<Weight> gating_einsum_w2;
|
|
MatPtrT<Weight> linear_w;
|
|
// > W8 is likely helpful.
|
|
MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
|
|
MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
|
|
MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
|
|
MatPtrT<WeightF32OrBF16> post_ffw_norm_scale;
|
|
|
|
MatPtrT<float> ffw_gating_biases;
|
|
MatPtrT<float> ffw_output_biases;
|
|
|
|
MatPtrT<Weight> attn_vec_einsum_w; // Use att_weights instead of this.
|
|
MatPtrT<Weight> att_weights; // Use this instead of attn_vec_einsum_w.
|
|
|
|
MatPtrT<WeightF32OrBF16> key_norm_scale;
|
|
MatPtrT<WeightF32OrBF16> query_norm_scale;
|
|
|
|
const LayerConfig& layer_config;
|
|
|
|
// Calls `func(TensorArgs)` for each tensor which is in use for the
|
|
// current `layer_config`. `other1` and `other2` are optional arguments so we
|
|
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
|
|
// Public because also called by `WeightsPtrs`.
|
|
template <class Func>
|
|
void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
|
|
LayerWeightsPtrs<Weight>* other2, Func func) {
|
|
if (layer_config.type == LayerAttentionType::kVit) {
|
|
// MHA.
|
|
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
|
|
func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
|
|
// Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
|
|
// must not be padded.
|
|
func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
|
|
// MlpBlock.
|
|
func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_1_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_1_b, kMustRead));
|
|
// LayerNorm.
|
|
func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead));
|
|
return;
|
|
}
|
|
if (layer_config.type == LayerAttentionType::kGemma) {
|
|
// Either read from file, or allocated during Fixup().
|
|
func(TENSOR_ARGS(att_weights, kMaybeRead));
|
|
func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
|
|
} else {
|
|
func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
|
|
func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
|
|
func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
|
|
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
|
|
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
|
|
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
|
|
// conv_w and gate_w are not accessed via Row(), hence must not be padded.
|
|
// Note that *biases are 1D, hence packing/padding does not matter.
|
|
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
|
|
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
|
|
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
|
|
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
|
|
func(TENSOR_ARGS(griffin.a, kMustRead));
|
|
}
|
|
{
|
|
func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
|
|
func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
|
|
func(TENSOR_ARGS(linear_w, kMaybeRead));
|
|
func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead));
|
|
}
|
|
|
|
if (layer_config.post_norm == PostNormType::Scale) {
|
|
func(TENSOR_ARGS(post_attention_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead));
|
|
}
|
|
if (layer_config.use_qk_norm) {
|
|
func(TENSOR_ARGS(key_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(query_norm_scale, kMustRead));
|
|
}
|
|
|
|
if (layer_config.ff_biases) {
|
|
func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
|
|
func(TENSOR_ARGS(ffw_output_biases, kMustRead));
|
|
}
|
|
|
|
if (layer_config.softmax_attn_output_biases &&
|
|
layer_config.type == LayerAttentionType::kGemma) {
|
|
func(TENSOR_ARGS(attention_output_biases, kMustRead));
|
|
}
|
|
} // `ForEachTensor`
|
|
|
|
// Zero-initializes all allocated tensors in the layer.
|
|
void ZeroInit() {
|
|
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
|
|
if (!t.mat.HasPtr()) return;
|
|
gcpp::ZeroInit(t.mat);
|
|
});
|
|
}
|
|
|
|
// Must be called after reading weights via `ForEachTensor`.
|
|
// TODO: exporters should bake this into the weights already.
|
|
// WARNING: called from multiple threads; `mat_owners` requires a lock.
|
|
void Fixup(std::vector<MatOwner>& mat_owners) {
|
|
InitAttWeights(mat_owners);
|
|
SplitW1();
|
|
SplitAttW1();
|
|
}
|
|
|
|
private:
|
|
// Copies att_weights from `attn_vec_einsum_w`.
|
|
void InitAttWeights(std::vector<MatOwner>& mat_owners) {
|
|
// We only use this tensor for Gemma layers.
|
|
if (layer_config.type != LayerAttentionType::kGemma) return;
|
|
|
|
// Files must have one or the other.
|
|
HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
|
|
// Done if we already read the transposed tensor.
|
|
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
|
|
|
|
// NUQ is handled by a specialization in weights.cc.
|
|
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
|
|
|
|
const size_t model_dim = layer_config.model_dim;
|
|
const size_t heads = layer_config.heads;
|
|
const size_t qkv_dim = layer_config.qkv_dim;
|
|
|
|
// Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
|
|
HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
|
|
HWY_ASSERT(att_weights.Rows() == model_dim);
|
|
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
|
|
HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
|
|
HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
|
|
|
|
{
|
|
static std::mutex m;
|
|
std::lock_guard<std::mutex> lock(m);
|
|
mat_owners.push_back(MatOwner());
|
|
mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
|
|
}
|
|
|
|
const size_t T_bytes = att_weights.ElementBytes();
|
|
for (size_t m = 0; m < model_dim; ++m) {
|
|
uint8_t* HWY_RESTRICT out_row =
|
|
reinterpret_cast<uint8_t*>(att_weights.Row(m));
|
|
for (size_t h = 0; h < heads; ++h) {
|
|
hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m),
|
|
out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
|
|
}
|
|
}
|
|
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
|
}
|
|
|
|
// For FFN. Fast, only updates pointers.
|
|
void SplitW1() {
|
|
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
|
|
if (layer_config.type == LayerAttentionType::kVit) return;
|
|
|
|
// Files have both or neither of w1 and w2.
|
|
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
|
|
// w is mutually exclusive with w1 and w2 in the file.
|
|
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
|
|
// Done if we already read split tensors. Note that they are not
|
|
// necessarily the same type.
|
|
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
|
|
|
|
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
|
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
|
|
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
|
|
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
|
|
// Cols are the model_dim but we don't have ModelConfig here.
|
|
HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
|
|
HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
|
|
|
|
const size_t stride = gating_einsum_w.Stride();
|
|
gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
|
|
gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
|
|
gating_einsum_w1.SetType(gating_einsum_w.GetType());
|
|
gating_einsum_w2.SetType(gating_einsum_w.GetType());
|
|
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
|
|
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
|
|
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
|
|
}
|
|
|
|
// For attention, which might not have a w2. Fast, only updates pointers.
|
|
void SplitAttW1() {
|
|
// We only use this tensor for Gemma layers.
|
|
if (layer_config.type != LayerAttentionType::kGemma) return;
|
|
|
|
// w is mutually exclusive with w1 in the file.
|
|
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
|
|
// Done if we already read split tensors. Note that w2 does not exist for
|
|
// MHA, and otherwise might not be the same type.
|
|
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
|
|
|
|
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
|
|
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
|
|
|
|
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
|
|
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
|
|
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
|
|
// Cols are the model_dim but we don't have ModelConfig here.
|
|
HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
|
|
HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
|
|
|
|
const size_t stride = qkv_einsum_w.Stride();
|
|
qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
|
|
qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
|
|
qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
|
|
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
|
|
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
|
|
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
|
|
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
|
|
}
|
|
};
|
|
|
|
// Holds layer-independent weight metadata and pointers plus per-layer
|
|
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with
|
|
// `LayerWeightsPtrs`, this class could be type-erased: member functions do not
|
|
// actually use the `Weight` template argument. The template does allow user
|
|
// code to dispatch only once. However, most tensors are large enough that
|
|
// dispatch at each usage would be feasible.
|
|
// TODO: move `gemma-inl.h` toward dispatch at each usage.
|
|
// TODO: rename to WeightsPtrs.
|
|
template <class Weight>
|
|
struct ModelWeightsPtrs {
|
|
using WeightT = Weight;
|
|
|
|
explicit ModelWeightsPtrs(const ModelConfig& config)
|
|
: tensors_(config),
|
|
// No suffix, these are per-model.
|
|
embedder_input_embedding("c_embedding", tensors_),
|
|
final_norm_scale("c_final_norm", tensors_),
|
|
vit_encoder_norm_bias("enc_norm_bias", tensors_),
|
|
vit_encoder_norm_scale("enc_norm_scale", tensors_),
|
|
vit_img_embedding_bias("img_emb_bias", tensors_),
|
|
vit_img_embedding_kernel("img_emb_kernel", tensors_),
|
|
vit_img_pos_embedding("img_pos_emb", tensors_),
|
|
vit_img_head_bias("img_head_bias", tensors_),
|
|
vit_img_head_kernel("img_head_kernel", tensors_),
|
|
mm_embed_norm("mm_embed_norm", tensors_),
|
|
weights_config(config) {
|
|
c_layers.reserve(config.layer_configs.size());
|
|
for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) {
|
|
const LayerConfig& layer_config = config.layer_configs[idx];
|
|
c_layers.emplace_back(idx, layer_config, tensors_);
|
|
}
|
|
for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) {
|
|
const LayerConfig& layer_config = config.vit_config.layer_configs[idx];
|
|
vit_layers.emplace_back(idx, layer_config, tensors_);
|
|
}
|
|
}
|
|
|
|
~ModelWeightsPtrs() = default;
|
|
// = F32 if weights are F32, else BF16.
|
|
using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;
|
|
|
|
// Passed to all `MatPtrT` initializers, hence must be initialized first.
|
|
const TensorInfoRegistry tensors_;
|
|
|
|
// TODO: switch to SFP?
|
|
MatPtrT<WeightF32OrBF16> embedder_input_embedding;
|
|
MatPtrT<WeightF32OrBF16> final_norm_scale;
|
|
|
|
// Vit parts.
|
|
MatPtrT<WeightF32OrBF16> vit_encoder_norm_bias;
|
|
MatPtrT<WeightF32OrBF16> vit_encoder_norm_scale;
|
|
MatPtrT<float> vit_img_embedding_bias;
|
|
MatPtrT<WeightF32OrBF16> vit_img_embedding_kernel;
|
|
MatPtrT<float> vit_img_pos_embedding;
|
|
// The head maps from VitConfig::model_dim (Vit final layer) to
|
|
// model_dim (LLM input).
|
|
MatPtrT<float> vit_img_head_bias;
|
|
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
|
|
|
|
MatPtrT<WeightF32OrBF16> mm_embed_norm;
|
|
|
|
const ModelConfig& weights_config;
|
|
|
|
std::vector<LayerWeightsPtrs<Weight>> c_layers;
|
|
std::vector<LayerWeightsPtrs<Weight>> vit_layers;
|
|
|
|
const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
|
|
return &c_layers[layer];
|
|
}
|
|
LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
|
|
const LayerWeightsPtrs<Weight>* VitLayer(size_t layer) const {
|
|
return &vit_layers[layer];
|
|
}
|
|
LayerWeightsPtrs<Weight>* VitLayer(size_t layer) {
|
|
return &vit_layers[layer];
|
|
}
|
|
|
|
// Called via `CallT`. `other1` and `other2` are usually null, but can be
|
|
// used to copy from another set of weights. Public because called by tests
|
|
// and `WeightsOwner`.
|
|
template <class Func>
|
|
void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
|
|
ModelWeightsPtrs<Weight>* other2, Func func) {
|
|
LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
|
|
LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
|
|
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
|
|
func(TENSOR_ARGS(final_norm_scale, kMustRead));
|
|
|
|
if (!weights_config.vit_config.layer_configs.empty()) { // Vit parts.
|
|
func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_embedding_kernel, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_pos_embedding, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));
|
|
|
|
if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
|
|
func(TENSOR_ARGS(mm_embed_norm, kMustRead));
|
|
}
|
|
}
|
|
|
|
for (size_t layer_idx = 0; layer_idx < c_layers.size(); ++layer_idx) {
|
|
if (other1) other_layer1 = other1->GetLayer(layer_idx);
|
|
if (other2) other_layer2 = other2->GetLayer(layer_idx);
|
|
GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
|
|
}
|
|
|
|
HWY_ASSERT(weights_config.vit_config.layer_configs.empty() ==
|
|
vit_layers.empty());
|
|
for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
|
|
HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
|
|
LayerAttentionType::kVit);
|
|
other_layer1 = other1 ? other1->VitLayer(layer_idx) : nullptr;
|
|
other_layer2 = other2 ? other2->VitLayer(layer_idx) : nullptr;
|
|
VitLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
|
|
}
|
|
} // `ForEachTensor`
|
|
|
|
// Zero-initializes only the allocated tensors in `*this`.
|
|
void ZeroInit() {
|
|
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
|
|
if (!t.mat.HasPtr()) return;
|
|
gcpp::ZeroInit(t.mat);
|
|
});
|
|
}
|
|
|
|
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
|
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
|
|
ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
|
|
[](const TensorArgs& t) {
|
|
if (!t.mat.HasPtr()) return;
|
|
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
|
|
CopyMat(*t.other_mat1, t.mat);
|
|
});
|
|
}
|
|
|
|
// For reshaping file tensors to the shape expected by the code. This would
|
|
// ideally already happen in the importer. Must be called after reading and
|
|
// updating the attention weights.
|
|
void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool) {
|
|
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
|
|
GetLayer(layer)->Fixup(mat_owners);
|
|
});
|
|
|
|
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
|
|
VitLayer(layer)->Fixup(mat_owners);
|
|
});
|
|
}
|
|
}; // `WeightsPtrs`
|
|
#undef TENSOR_ARGS
|
|
|
|
// Type-erased facade for `WeightsPtrs<T>`, stored inside the non-template
|
|
// `Gemma`. Also owns the underlying memory.
|
|
class WeightsOwner {
|
|
public:
|
|
// `weight_type` is obtained from `ModelConfig` in `ModelStore`.
|
|
WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
|
|
|
|
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
|
|
// override for whether to map blobs or read them.
|
|
void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
|
|
hwy::ThreadPool& pool);
|
|
|
|
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
|
|
// calls `ForEachTensor`.
|
|
template <class Func, typename... TArgs>
|
|
decltype(auto) CallT(const Func& func, TArgs&&... args) const {
|
|
if (HWY_LIKELY(weight_type_ == Type::kSFP)) {
|
|
return func(sfp_weights_, std::forward<TArgs>(args)...);
|
|
} else if (weight_type_ == Type::kNUQ) {
|
|
return func(nuq_weights_, std::forward<TArgs>(args)...);
|
|
} else if (weight_type_ == Type::kBF16) {
|
|
return func(bf16_weights_, std::forward<TArgs>(args)...);
|
|
}
|
|
return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
|
|
}
|
|
|
|
// For writers:
|
|
|
|
// Adds one blob for each tensor's data and returns all serialized MatPtr.
|
|
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
|
|
|
|
private:
|
|
Type weight_type_;
|
|
|
|
// Allocates `*_weights_`, but not yet the tensors inside. This is split out
|
|
// of `CallT` so that can be const.
|
|
void AllocatePointer(const ModelConfig& config);
|
|
|
|
// Called by `ReadFromBlobs`.
|
|
void Fixup(hwy::ThreadPool& pool);
|
|
|
|
// Only one is non-null, determined by `weight_type_`.
|
|
std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
|
|
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
|
|
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
|
|
|
|
// Owns the memory referenced by all `MatPtr`.
|
|
std::vector<MatOwner> mat_owners_;
|
|
};
|
|
|
|
} // namespace gcpp
|
|
|
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|