// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "compression/blob_store.h"  // BlobWriter
#include "compression/shared.h"      // IsF32
#include "gemma/configs.h"           // ModelConfig
#include "gemma/model_store.h"       // ModelStore
#include "gemma/tensor_info.h"       // TensorInfoRegistry
#include "util/mat.h"                // MatPtr
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

// Argument passed to the `ForEachTensor` callback.
struct TensorArgs {
  // `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same
  // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
  // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
  // `flags` is a combination of zero or more `Flags`.
  TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
             int flags)
      : mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) {
    // Combining `kMaybeRead` with `kOnlyAllocate` does not make sense.
    HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate));
  }

  MatPtr& mat;
  const MatPtr* other_mat1;  // either/both can be nullptr.
  const MatPtr* other_mat2;

  // TODO: freestanding enum class instead? These are mutually exclusive.
  enum Flags {
    // Read the tensor from the file and abort if it is not found.
    kMustRead = 0,
    // Not an error if the tensor is not present in the file. For example,
    // the _w1/_w2 tensors are not always present.
    kMaybeRead = 1,
    // Do not attempt to read, just allocate the tensor. Used for `Reshape`.
    kOnlyAllocate = 2,
  };
  const int flags;
};

// Shorthand for creating the argument to the `ForEachTensor` callback. A macro
// seems less bad than member pointer syntax.
#define TENSOR_ARGS(mat, flag)                     \
  TensorArgs(mat, other1 ? &other1->mat : nullptr, \
             other2 ? &other2->mat : nullptr, TensorArgs::flag)
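
// For illustration only (assuming `other1` and `other2` are the
// `LayerWeightsPtrs` pointers in scope inside `ForEachTensor` below),
// `TENSOR_ARGS(linear_w, kMustRead)` expands to roughly:
//   TensorArgs(linear_w, other1 ? &other1->linear_w : nullptr,
//              other2 ? &other2->linear_w : nullptr, TensorArgs::kMustRead)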

// Per-layer weight metadata and pointers. The tensor data is owned by
// `WeightsOwner`. Note that this class could be type-erased: member functions
// do not actually use the `Weight` template argument. See `WeightsPtrs`.
// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth
// for all tensor shapes.
template <class Weight>
struct LayerWeightsPtrs {
  static inline std::string Concat(const char* base_name,
                                   const std::string& suffix) {
    return std::string(base_name) + suffix;
  }

  // Initializes tensor metadata without allocating.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : suffix_(LayerSuffix(layer_idx)),
        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
        qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
        qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
        qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
        attention_output_biases(Concat("attn_ob", suffix_), tensors),
        griffin(
            {.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors},
             .linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors},
             .linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors},
             .linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors},
             .linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors},
             .linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors},
             .conv_w = {Concat("gr_conv_w", suffix_), tensors},
             .conv_biases = {Concat("gr_conv_b", suffix_), tensors},
             .gate_w = {Concat("gr_gate_w", suffix_), tensors},
             .gate_biases = {Concat("gr_gate_b", suffix_), tensors},
             .a = {Concat("gr_a", suffix_), tensors}}),
        // MultiHeadDotProductAttention.
        vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors},
             .attn_out_b = {Concat("attn_out_b", suffix_), tensors},
             .qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors},
             .qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors},
             .linear_0_w = {Concat("linear_0_w", suffix_), tensors},
             .linear_0_b = {Concat("linear_0_b", suffix_), tensors},
             .linear_1_w = {Concat("linear_1_w", suffix_), tensors},
             .linear_1_b = {Concat("linear_1_b", suffix_), tensors},
             .layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors},
             .layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors},
             .layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors},
             .layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}),
        gating_einsum_w(Concat("gating_ein", suffix_), tensors),
        gating_einsum_w1(Concat("gating1_w", suffix_), tensors),
        gating_einsum_w2(Concat("gating2_w", suffix_), tensors),
        linear_w(Concat("linear_w", suffix_), tensors),
        pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors),
        pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors),
        post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors),
        post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
        ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
        ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
        att_weights(Concat("att_w", suffix_), tensors),
        key_norm_scale(Concat("key_norm", suffix_), tensors),
        query_norm_scale(Concat("query_norm", suffix_), tensors),
        layer_config(config) {}
  ~LayerWeightsPtrs() = default;

  const std::string suffix_;

  // If weights are f32, this is also f32; otherwise at least bf16. Useful for
  // ops that do not yet support smaller compressed types, or that require at
  // least bf16. If weights are complex, this is also complex.
  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
              hwy::If<hwy::IsSame<Weight, double>(), double,
                      hwy::If<IsF32<Weight>(), float, BF16>>>;
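  // For example, per the `hwy::If` chain above: Weight = float yields float;
  // Weight = BF16, SfpStream or NuqStream yields BF16; Weight = double or
  // std::complex<double> yields itself.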

  MatPtrT<Weight> attn_vec_einsum_w;
  // `qkv_einsum_w` holds 2 different matrices, which may be separated out
  // into `qkv_einsum_w1`/`qkv_einsum_w2`. Which of them is read depends on
  // what is present in the file. At inference, the one with a non-null ptr
  // is used.
  MatPtrT<Weight> qkv_einsum_w;
  MatPtrT<Weight> qkv_einsum_w1;
  MatPtrT<Weight> qkv_einsum_w2;
  MatPtrT<float> attention_output_biases;

  struct {
    MatPtrT<Weight> linear_x_w;
    MatPtrT<float> linear_x_biases;
    MatPtrT<Weight> linear_y_w;
    MatPtrT<float> linear_y_biases;
    MatPtrT<Weight> linear_out_w;
    MatPtrT<float> linear_out_biases;
    MatPtrT<float> conv_w;
    MatPtrT<float> conv_biases;
    MatPtrT<Weight> gate_w;
    MatPtrT<float> gate_biases;
    MatPtrT<float> a;
  } griffin;

  struct {
    // MultiHeadDotProductAttention.
    MatPtrT<WeightF32OrBF16> attn_out_w;
    MatPtrT<float> attn_out_b;
    MatPtrT<WeightF32OrBF16> qkv_einsum_w;
    MatPtrT<float> qkv_einsum_b;
    // MlpBlock.
    MatPtrT<WeightF32OrBF16> linear_0_w;
    MatPtrT<float> linear_0_b;
    MatPtrT<WeightF32OrBF16> linear_1_w;
    MatPtrT<float> linear_1_b;
    // LayerNorm.
    MatPtrT<WeightF32OrBF16> layer_norm_0_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_0_scale;
    MatPtrT<WeightF32OrBF16> layer_norm_1_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
  } vit;

  // `gating_einsum_w` holds 2 different matrices, which may be separated out
  // into `gating_einsum_w1`/`gating_einsum_w2`. Which of them is read depends
  // on what is present in the file. At inference, the one with a non-null ptr
  // is used.
  MatPtrT<Weight> gating_einsum_w;
  MatPtrT<Weight> gating_einsum_w1;
  MatPtrT<Weight> gating_einsum_w2;
  MatPtrT<Weight> linear_w;
  // We don't yet have an RMSNorm that accepts all Weight.
  MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
  MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> post_ffw_norm_scale;

  MatPtrT<float> ffw_gating_biases;
  MatPtrT<float> ffw_output_biases;

  MatPtrT<Weight> att_weights;  // For Reshape(); kOnlyAllocate.

  MatPtrT<WeightF32OrBF16> key_norm_scale;
  MatPtrT<WeightF32OrBF16> query_norm_scale;

  const LayerConfig& layer_config;

  // Calls `func(TensorArgs)` for each tensor which is in use for the
  // current `layer_config`. `other1` and `other2` are optional arguments so we
  // can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
  // Public because also called by `WeightsPtrs`.
  template <class Func>
  void ForEachTensor(const LayerWeightsPtrs<Weight>* other1,
                     const LayerWeightsPtrs<Weight>* other2, Func func) {
    if (layer_config.type == LayerAttentionType::kVit) {
      // MHA.
      func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
      func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
      func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
      func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead));
      // MlpBlock.
      func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_b, kMustRead));
      // LayerNorm.
      func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead));
      return;
    }
    if (layer_config.type == LayerAttentionType::kGemma) {
      // Not read, will be filled by Reshape() from `attn_vec_einsum_w`.
      func(TENSOR_ARGS(att_weights, kOnlyAllocate));
      func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead));
      func(TENSOR_ARGS(qkv_einsum_w, kMustRead));
      func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
    } else {
      func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
      func(TENSOR_ARGS(griffin.conv_w, kMustRead));
      func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
      func(TENSOR_ARGS(griffin.gate_w, kMustRead));
      func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
      func(TENSOR_ARGS(griffin.a, kMustRead));
    }
    {
      func(TENSOR_ARGS(gating_einsum_w, kMustRead));
      func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
      func(TENSOR_ARGS(linear_w, kMustRead));
      func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead));
    }

    if (layer_config.post_norm == PostNormType::Scale) {
      func(TENSOR_ARGS(post_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead));
    }
    if (layer_config.use_qk_norm) {
      func(TENSOR_ARGS(key_norm_scale, kMustRead));
      func(TENSOR_ARGS(query_norm_scale, kMustRead));
    }

    if (layer_config.ff_biases) {
      func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
      func(TENSOR_ARGS(ffw_output_biases, kMustRead));
    }

    if (layer_config.softmax_attn_output_biases &&
        layer_config.type == LayerAttentionType::kGemma) {
      func(TENSOR_ARGS(attention_output_biases, kMustRead));
    }
  }  // `ForEachTensor`

  // Zero-initializes all allocated tensors in the layer.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  void RandInit(float stddev, std::mt19937& gen) {
    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::RandInit(t.mat, stddev, gen);
    });
  }

  // Allocates memory for all the tensors in the layer. Note that this is slow
  // (non-parallel) and only used for a stand-alone layer.
  void AllocateForTest(MatOwners& mat_owners) {
    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      // `backprop/` does not use row accessors and hence requires kPacked.
      mat_owners.AllocateFor(t.mat, MatPadding::kPacked);
    });
  }

  // Initializes `att_weights` from `attn_vec_einsum_w`, hence it must be
  // called after reading weights via `ForEachTensor`.
  // TODO: update compression/convert_weights to bake this in.
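  // Illustration of the row mapping below: with heads = 2 and model_dim = 3,
  // `att_weights` row m is the concatenation of `attn_vec_einsum_w` rows
  // m and model_dim + m, i.e. row 0 <- rows {0, 3}, row 1 <- rows {1, 4},
  // row 2 <- rows {2, 5}, each contributing qkv_dim elements.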
  void Reshape() {
    // NUQ is handled by a specialization in weights.cc.
    HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);

    const size_t model_dim = layer_config.model_dim;
    const size_t heads = layer_config.heads;
    const size_t qkv_dim = layer_config.qkv_dim;

    // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
    HWY_ASSERT(att_weights.HasPtr());
    HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
    HWY_ASSERT(att_weights.Rows() == model_dim);
    HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
    HWY_ASSERT(attn_vec_einsum_w.HasPtr());
    HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
    HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
    const size_t T_bytes = att_weights.ElementBytes();
    for (size_t m = 0; m < model_dim; ++m) {
      uint8_t* HWY_RESTRICT out_row =
          reinterpret_cast<uint8_t*>(att_weights.Row(m));
      for (size_t h = 0; h < heads; ++h) {
        hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m),
                       out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
      }
    }
    att_weights.SetScale(attn_vec_einsum_w.Scale());
  }
};

// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with
// `LayerWeightsPtrs`, this class could be type-erased: member functions do not
// actually use the `Weight` template argument. The template does allow user
// code to dispatch only once. However, most tensors are large enough that
// dispatch at each usage would be feasible.
// TODO: move `gemma-inl.h` toward dispatch at each usage.
// TODO: rename to WeightsPtrs.
template <class Weight>
struct ModelWeightsPtrs {
  using WeightT = Weight;

  explicit ModelWeightsPtrs(const ModelConfig& config)
      : tensors_(config),
        // No suffix, these are per-model.
        embedder_input_embedding("c_embedding", tensors_),
        final_norm_scale("c_final_norm", tensors_),
        vit_encoder_norm_bias("enc_norm_bias", tensors_),
        vit_encoder_norm_scale("enc_norm_scale", tensors_),
        vit_img_embedding_bias("img_emb_bias", tensors_),
        vit_img_embedding_kernel("img_emb_kernel", tensors_),
        vit_img_pos_embedding("img_pos_emb", tensors_),
        vit_img_head_bias("img_head_bias", tensors_),
        vit_img_head_kernel("img_head_kernel", tensors_),
        mm_embed_norm("mm_embed_norm", tensors_),
        weights_config(config) {
    c_layers.reserve(config.layer_configs.size());
    for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.layer_configs[idx];
      c_layers.emplace_back(idx, layer_config, tensors_);
    }
    for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.vit_config.layer_configs[idx];
      vit_layers.emplace_back(idx, layer_config, tensors_);
    }
  }

  ~ModelWeightsPtrs() = default;
  // = F32 if weights are F32, else BF16.
  using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;

  // Passed to all `MatPtrT` initializers, hence must be initialized first.
  const TensorInfoRegistry tensors_;

  // TODO: switch to SFP?
  MatPtrT<WeightF32OrBF16> embedder_input_embedding;
  MatPtrT<WeightF32OrBF16> final_norm_scale;

  // Vit parts.
  MatPtrT<WeightF32OrBF16> vit_encoder_norm_bias;
  MatPtrT<WeightF32OrBF16> vit_encoder_norm_scale;
  MatPtrT<float> vit_img_embedding_bias;
  MatPtrT<WeightF32OrBF16> vit_img_embedding_kernel;
  MatPtrT<float> vit_img_pos_embedding;
  // The head maps from VitConfig::model_dim (Vit final layer) to
  // model_dim (LLM input).
  MatPtrT<float> vit_img_head_bias;
  MatPtrT<WeightF32OrBF16> vit_img_head_kernel;

  MatPtrT<WeightF32OrBF16> mm_embed_norm;

  const ModelConfig& weights_config;

  std::vector<LayerWeightsPtrs<Weight>> c_layers;
  std::vector<LayerWeightsPtrs<Weight>> vit_layers;

  const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
    return &c_layers[layer];
  }
  LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
  const LayerWeightsPtrs<Weight>* VitLayer(size_t layer) const {
    return &vit_layers[layer];
  }
  LayerWeightsPtrs<Weight>* VitLayer(size_t layer) {
    return &vit_layers[layer];
  }

  // Called via `CallT`. `other1` and `other2` are usually null, but can be
  // used to copy from another set of weights. Public because called by tests
  // and `WeightsOwner`.
  template <class Func>
  void ForEachTensor(const ModelWeightsPtrs<Weight>* other1,
                     const ModelWeightsPtrs<Weight>* other2, Func func) {
    const LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
    const LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
    func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
    func(TENSOR_ARGS(final_norm_scale, kMustRead));

    if (!weights_config.vit_config.layer_configs.empty()) {  // Vit parts.
      func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
      func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_kernel, kMustRead));
      func(TENSOR_ARGS(vit_img_pos_embedding, kMustRead));
      func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));

      if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
        func(TENSOR_ARGS(mm_embed_norm, kMustRead));
      }
    }

    for (size_t layer_idx = 0; layer_idx < c_layers.size(); ++layer_idx) {
      if (other1) other_layer1 = other1->GetLayer(layer_idx);
      if (other2) other_layer2 = other2->GetLayer(layer_idx);
      GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }

    HWY_ASSERT(weights_config.vit_config.layer_configs.empty() ==
               vit_layers.empty());
    for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
      HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
                 LayerAttentionType::kVit);
      other_layer1 = other1 ? other1->VitLayer(layer_idx) : nullptr;
      other_layer2 = other2 ? other2->VitLayer(layer_idx) : nullptr;
      VitLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }
  }  // `ForEachTensor`
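
  // Illustrative use only: counting how many tensors are currently allocated,
  // assuming `weights` is a `ModelWeightsPtrs<T>`:
  //   size_t allocated = 0;
  //   weights.ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
  //     if (t.mat.HasPtr()) ++allocated;
  //   });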

  // Zero-initializes only the allocated tensors in `*this`.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  void RandInit(float stddev, std::mt19937& gen) {
    ForEachTensor(nullptr, nullptr, [stddev, &gen](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::RandInit(t.mat, stddev, gen);
    });
  }

  // Copies only the allocated tensors in `*this` from tensors in `other`.
  void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
    ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
      CopyMat(*t.other_mat1, t.mat);
    });
  }

  // Instead of reading, only allocates memory for all tensors. Used by
  // `optimizer.cc` via the `Gemma` constructor without weights.
  void AllocateForTest(MatOwners& mat_owners, hwy::ThreadPool& pool) {
    // First get a list of all the tensors.
    std::vector<MatPtr*> all_mat;
    all_mat.reserve(10 * c_layers.size());
    ForEachTensor(nullptr, nullptr, [&all_mat](const TensorArgs& t) {
      all_mat.push_back(&t.mat);
    });

    // `backprop/` does not use row accessors and hence requires kPacked.
    mat_owners.AllocateFor(all_mat, MatPadding::kPacked, pool);
  }

  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Must be called after reading and
  // updating the attention weights.
  void Reshape(hwy::ThreadPool& pool) {
    pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
      GetLayer(layer)->Reshape();
    });

    pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
      VitLayer(layer)->Reshape();
    });
  }
};  // `ModelWeightsPtrs`
#undef TENSOR_ARGS
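
// Minimal usage sketch (illustrative only, assuming a valid `ModelConfig`
// named `config` and a thread pool):
//   hwy::ThreadPool pool(4);
//   MatOwners owners;
//   ModelWeightsPtrs<float> weights(config);  // metadata only, no data yet
//   weights.AllocateForTest(owners, pool);    // or read via `WeightsOwner`
//   weights.ZeroInit();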

// Type-erased facade for `WeightsPtrs<T>`, stored inside the non-template
// `Gemma`. Also owns the underlying memory.
class WeightsOwner {
 public:
  // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
  WeightsOwner(Type weight_type) : weight_type_(weight_type) {}

  // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
  // allocates memory and reshapes. Aborts on error.
  void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
                      hwy::ThreadPool& pool);

  // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
  // calls `ForEachTensor`.
  template <class Func, typename... TArgs>
  decltype(auto) CallT(const Func& func, TArgs&&... args) const {
    if (HWY_LIKELY(weight_type_ == Type::kSFP)) {
      return func(sfp_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kNUQ) {
      return func(nuq_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kF32) {
      return func(float_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kBF16) {
      return func(bf16_weights_, std::forward<TArgs>(args)...);
    }
    return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
  }
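
  // Illustrative use only: a generic lambda passed to `CallT` receives the
  // `std::unique_ptr<ModelWeightsPtrs<T>>` whose T matches `weight_type_`:
  //   owner.CallT([](const auto& weights) { HWY_ASSERT(weights != nullptr); });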

  // For writers:

  // Adds one blob for each tensor's data and returns all serialized MatPtr.
  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

  // For backprop/:

  // Only allocates; must subsequently call `ZeroInit` or `RandInit`.
  void AllocateForTest(const ModelConfig& config, hwy::ThreadPool& pool);
  void ZeroInit();
  void RandInit(float stddev, std::mt19937& gen);  // F32 or F64 only.
  void LogWeightStatsF32();

  ModelWeightsPtrs<float>* GetF32() const {
    HWY_ASSERT(weight_type_ == Type::kF32);
    return float_weights_.get();
  }

  // Usually taken care of by `ReadOrAllocate`, but must also be called by
  // `optimize_test`, which updates the attention weights from which this
  // copies.
  void Reshape(hwy::ThreadPool& pool);

 private:
  Type weight_type_;

  // Allocates `*_weights_`, but not yet the tensors inside. This is split out
  // of `CallT` so that `CallT` can be const.
  void AllocatePointer(const ModelConfig& config);

  // Only one is non-null, determined by `weight_type_`.
  std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
  std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
  std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
  std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;

  // Owns the memory referenced by all `MatPtr`.
  MatOwners mat_owners_;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_