// Mirror of https://github.com/google/gemma.cpp.git (429 lines, 16 KiB, C++).
// Copyright 2024 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "compression/types.h"
|
|
#include "gemma/configs.h" // ModelConfig
|
|
#include "gemma/gemma_args.h" // InferenceArgs
|
|
#include "gemma/model_store.h" // ModelStore
|
|
#include "gemma/tensor_info.h" // TensorInfoRegistry
|
|
#include "io/blob_store.h" // BlobWriter
|
|
#include "util/mat.h" // MatPtr
|
|
#include "util/threading_context.h"
|
|
|
|
namespace gcpp {
|
|
|
|
// Argument passed to the `ForEachTensor` callback.
|
|
struct TensorArgs {
|
|
// `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same
|
|
// name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
|
|
// (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
|
|
// `flags` is a combination of zero or more `Flags`.
|
|
TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
|
|
: mat(mat),
|
|
other_mat1(other_mat1),
|
|
other_mat2(other_mat2),
|
|
flags(flags) {}
|
|
|
|
MatPtr& mat;
|
|
MatPtr* other_mat1; // either/both can be nullptr.
|
|
MatPtr* other_mat2;
|
|
|
|
enum Flags {
|
|
// Default: Read the tensor from the file and abort if it is not found.
|
|
kMustRead = 0,
|
|
|
|
// Not an error if the tensor is not present in the file. For example,
|
|
// the _w1/_w2 tensors are not always present.
|
|
kMaybeRead = 1,
|
|
|
|
// Avoid padding tensor rows when reading.
|
|
kPacked = 2,
|
|
};
|
|
const int flags;
|
|
};
|
|
|
|
// Builds the `TensorArgs` for member `mat` that is handed to a
// `ForEachTensor` callback. Requires pointers named `other1` and `other2` in
// the enclosing scope; when one is null, the matching `other_mat` argument is
// nullptr. `flag` is substituted after `TensorArgs::`, so its first term is
// written unqualified (e.g. `kMustRead | TensorArgs::kPacked`).
// A macro seems less bad than member pointer syntax.
#define TENSOR_ARGS(mat, flag)                                \
  TensorArgs(mat, other1 != nullptr ? &other1->mat : nullptr, \
             other2 != nullptr ? &other2->mat : nullptr, TensorArgs::flag)
|
|
|
|
// Finds tensors by name in `TensorInfoRegistry` (constructed from
|
|
// `ModelConfig`) and constructs `MatPtr` metadata with those shapes.
|
|
class MatFinder {
|
|
public:
|
|
MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors)
|
|
: suffix_(suffix), tensors_(tensors) {}
|
|
|
|
// Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`.
|
|
MatPtr operator()(const std::string& base_name) const {
|
|
const std::string name = std::string(base_name) + suffix_;
|
|
return MatPtr(name.c_str(), Type::kUnknown,
|
|
ExtentsFromInfo(tensors_.Find(name)));
|
|
}
|
|
|
|
private:
|
|
const std::string suffix_;
|
|
const TensorInfoRegistry& tensors_;
|
|
};
|
|
|
|
// Per-layer weight metadata and pointers. The tensor data is owned by
|
|
// `MatOwner`.
|
|
struct LayerWeightsPtrs {
|
|
// Initializes tensor metadata without allocating.
|
|
// NOTE: do not store layer_idx, TransformerLayer and Attention may use
|
|
// other values for purposes of the KV cache.
|
|
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
|
|
const TensorInfoRegistry& tensors)
|
|
: layer_idx(layer_idx),
|
|
finder_(LayerSuffix(layer_idx), tensors),
|
|
qkv_einsum_w(finder_("qkv_ein")),
|
|
qkv_einsum_w1(finder_("qkv1_w")),
|
|
qkv_einsum_w2(finder_("qkv2_w")),
|
|
attention_output_biases(finder_("attn_ob")),
|
|
// MultiHeadDotProductAttention.
|
|
vit({.attn_out_w = finder_("attn_out_w"),
|
|
.attn_out_b = finder_("attn_out_b"),
|
|
.qkv_einsum_w = finder_("qkv_ein_w"),
|
|
.qkv_einsum_b = finder_("qkv_ein_b"),
|
|
.linear_0_w = finder_("linear_0_w"),
|
|
.linear_0_b = finder_("linear_0_b"),
|
|
.linear_1_w = finder_("linear_1_w"),
|
|
.linear_1_b = finder_("linear_1_b"),
|
|
.layer_norm_0_bias = finder_("ln_0_bias"),
|
|
.layer_norm_0_scale = finder_("ln_0_scale"),
|
|
.layer_norm_1_bias = finder_("ln_1_bias"),
|
|
.layer_norm_1_scale = finder_("ln_1_scale")}),
|
|
gating_einsum_w(finder_("gating_ein")),
|
|
gating_einsum_w1(finder_("gating1_w")),
|
|
gating_einsum_w2(finder_("gating2_w")),
|
|
linear_w(finder_("linear_w")),
|
|
pre_attention_norm_scale(finder_("pre_att_ns")),
|
|
pre_ffw_norm_scale(finder_("pre_ff_ns")),
|
|
post_attention_norm_scale(finder_("post_att_ns")),
|
|
post_ffw_norm_scale(finder_("post_ff_ns")),
|
|
ffw_gating_biases(finder_("ffw_gat_b")),
|
|
ffw_output_biases(finder_("ffw_out_b")),
|
|
|
|
attn_vec_einsum_w(finder_("att_ein")),
|
|
att_weights(finder_("att_w")),
|
|
|
|
key_norm_scale(finder_("key_norm")),
|
|
query_norm_scale(finder_("query_norm")),
|
|
|
|
layer_config(config) {
|
|
}
|
|
~LayerWeightsPtrs() = default;
|
|
|
|
const size_t layer_idx;
|
|
const MatFinder finder_;
|
|
|
|
// Files either have qkv_einsum_w with 2 stacked matrices or separate
|
|
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
|
|
MatPtr qkv_einsum_w;
|
|
MatPtr qkv_einsum_w1;
|
|
MatPtr qkv_einsum_w2;
|
|
MatPtrT<float> attention_output_biases;
|
|
|
|
struct {
|
|
// MultiHeadDotProductAttention.
|
|
MatPtr attn_out_w; // at least BF16.
|
|
MatPtrT<float> attn_out_b;
|
|
MatPtr qkv_einsum_w; // at least BF16.
|
|
MatPtrT<float> qkv_einsum_b;
|
|
// MlpBlock.
|
|
MatPtr linear_0_w; // at least BF16.
|
|
MatPtrT<float> linear_0_b;
|
|
MatPtr linear_1_w; // at least BF16.
|
|
MatPtrT<float> linear_1_b;
|
|
// LayerNorm.
|
|
MatPtr layer_norm_0_bias; // at least BF16.
|
|
MatPtr layer_norm_0_scale; // at least BF16.
|
|
MatPtr layer_norm_1_bias; // at least BF16.
|
|
MatPtr layer_norm_1_scale; // at least BF16.
|
|
} vit;
|
|
|
|
// Files either have gating_einsum_w with 2 stacked matrices or separate
|
|
// w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h.
|
|
MatPtr gating_einsum_w;
|
|
MatPtr gating_einsum_w1;
|
|
MatPtr gating_einsum_w2;
|
|
MatPtr linear_w;
|
|
MatPtr pre_attention_norm_scale; // at least BF16.
|
|
MatPtr pre_ffw_norm_scale; // at least BF16.
|
|
MatPtr post_attention_norm_scale; // at least BF16.
|
|
MatPtr post_ffw_norm_scale; // at least BF16.
|
|
|
|
MatPtrT<float> ffw_gating_biases;
|
|
MatPtrT<float> ffw_output_biases;
|
|
|
|
MatPtr attn_vec_einsum_w; // Use att_weights instead of this.
|
|
MatPtr att_weights; // Use this instead of attn_vec_einsum_w.
|
|
|
|
MatPtr key_norm_scale; // at least BF16.
|
|
MatPtr query_norm_scale; // at least BF16.
|
|
|
|
const LayerConfig& layer_config;
|
|
|
|
// Calls `func(TensorArgs)` for each tensor which is in use for the
|
|
// current `layer_config`. `other1` and `other2` are optional arguments so we
|
|
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
|
|
// Public because also called by `WeightsPtrs`.
|
|
template <class Func>
|
|
void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2,
|
|
Func func) {
|
|
if (layer_config.type == LayerAttentionType::kVit) {
|
|
// MHA.
|
|
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
|
|
func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
|
|
// Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
|
|
// must not be padded.
|
|
func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
|
|
// MlpBlock.
|
|
func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_1_w, kMustRead));
|
|
func(TENSOR_ARGS(vit.linear_1_b, kMustRead));
|
|
// LayerNorm.
|
|
func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead));
|
|
return;
|
|
}
|
|
if (layer_config.type == LayerAttentionType::kGemma) {
|
|
// Either read from file, or allocated during Fixup().
|
|
func(TENSOR_ARGS(att_weights, kMaybeRead));
|
|
func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
|
|
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
|
|
}
|
|
{
|
|
func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
|
|
func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
|
|
func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
|
|
func(TENSOR_ARGS(linear_w, kMaybeRead));
|
|
func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead));
|
|
}
|
|
|
|
if (layer_config.post_norm == PostNormType::Scale) {
|
|
func(TENSOR_ARGS(post_attention_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead));
|
|
}
|
|
if (layer_config.use_qk_norm) {
|
|
func(TENSOR_ARGS(key_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(query_norm_scale, kMustRead));
|
|
}
|
|
|
|
if (layer_config.ff_biases) {
|
|
func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
|
|
func(TENSOR_ARGS(ffw_output_biases, kMustRead));
|
|
}
|
|
} // `ForEachTensor`
|
|
|
|
// Zero-initializes all allocated tensors in the layer.
|
|
void ZeroInit() {
|
|
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
|
|
if (!t.mat.HasPtr()) return;
|
|
gcpp::ZeroInit(t.mat);
|
|
});
|
|
}
|
|
|
|
// Must be called after reading weights via `ForEachTensor`.
|
|
// TODO: exporters should bake this into the weights already.
|
|
// WARNING: called from multiple threads; `mat_owners` requires a lock.
|
|
void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
|
|
|
|
private:
|
|
// Copies att_weights from `attn_vec_einsum_w`.
|
|
void InitAttWeights(std::vector<MatOwner>& mat_owners,
|
|
const Allocator& allocator);
|
|
|
|
// For FFN. Fast, only updates pointers.
|
|
void SplitW1();
|
|
|
|
// For attention, which might not have a w2. Fast, only updates pointers.
|
|
void SplitAttW1();
|
|
};
|
|
|
|
// Holds layer-independent weight metadata and pointers plus per-layer
|
|
// `LayerWeightsPtrs`. The tensor data is owned by `MatOwner`.
|
|
struct WeightsPtrs {
|
|
explicit WeightsPtrs(const ModelConfig& config)
|
|
: config_(config),
|
|
tensors_(config_),
|
|
finder_("", tensors_), // no suffix because these are per-model.
|
|
embedder_input_embedding(finder_("c_embedding")),
|
|
final_norm_scale(finder_("c_final_norm")),
|
|
vit_encoder_norm_bias(finder_("enc_norm_bias")),
|
|
vit_encoder_norm_scale(finder_("enc_norm_scale")),
|
|
vit_img_embedding_bias(finder_("img_emb_bias")),
|
|
vit_img_embedding_kernel(finder_("img_emb_kernel")),
|
|
vit_img_pos_embedding(finder_("img_pos_emb")),
|
|
vit_img_head_bias(finder_("img_head_bias")),
|
|
vit_img_head_kernel(finder_("img_head_kernel")),
|
|
mm_embed_norm(finder_("mm_embed_norm")),
|
|
c_layers() {
|
|
c_layers.reserve(config_.layer_configs.size());
|
|
for (size_t idx = 0; idx < config_.layer_configs.size(); ++idx) {
|
|
const LayerConfig& layer_config = config_.layer_configs[idx];
|
|
c_layers.emplace_back(idx, layer_config, tensors_);
|
|
}
|
|
for (size_t idx = 0; idx < config_.vit_config.layer_configs.size(); ++idx) {
|
|
const LayerConfig& layer_config = config_.vit_config.layer_configs[idx];
|
|
vit_layers.emplace_back(idx, layer_config, tensors_);
|
|
}
|
|
}
|
|
|
|
~WeightsPtrs() = default;
|
|
|
|
const ModelConfig& config_;
|
|
// Passed to finder_, hence must be initialized first.
|
|
const TensorInfoRegistry tensors_;
|
|
const MatFinder finder_;
|
|
|
|
// TODO: switch to SFP?
|
|
MatPtr embedder_input_embedding;
|
|
MatPtr final_norm_scale; // at least BF16.
|
|
|
|
// Vit parts.
|
|
MatPtr vit_encoder_norm_bias; // at least BF16.
|
|
MatPtr vit_encoder_norm_scale; // at least BF16.
|
|
MatPtrT<float> vit_img_embedding_bias;
|
|
MatPtr vit_img_embedding_kernel; // at least BF16.
|
|
MatPtr vit_img_pos_embedding; // F32?
|
|
// The head maps from VitConfig::model_dim (Vit final layer) to
|
|
// model_dim (LLM input).
|
|
MatPtrT<float> vit_img_head_bias;
|
|
MatPtr vit_img_head_kernel; // at least BF16.
|
|
|
|
MatPtr mm_embed_norm; // at least BF16.
|
|
|
|
std::vector<LayerWeightsPtrs> c_layers;
|
|
std::vector<LayerWeightsPtrs> vit_layers;
|
|
|
|
const LayerWeightsPtrs* GetLayer(size_t layer) const {
|
|
return &c_layers[layer];
|
|
}
|
|
LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; }
|
|
const LayerWeightsPtrs* VitLayer(size_t layer) const {
|
|
return &vit_layers[layer];
|
|
}
|
|
LayerWeightsPtrs* VitLayer(size_t layer) { return &vit_layers[layer]; }
|
|
|
|
// Called via `CallT`. `other1` and `other2` are usually null, but can be
|
|
// used to copy from another set of weights. Public because called by tests
|
|
// and `WeightsOwner`.
|
|
template <class Func>
|
|
void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
|
|
LayerWeightsPtrs* other_layer1 = nullptr;
|
|
LayerWeightsPtrs* other_layer2 = nullptr;
|
|
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
|
|
func(TENSOR_ARGS(final_norm_scale, kMustRead));
|
|
|
|
if (!config_.vit_config.layer_configs.empty()) { // Vit parts.
|
|
func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_embedding_kernel, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_pos_embedding, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
|
|
func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));
|
|
|
|
if (config_.wrapping == PromptWrapping::GEMMA_VLM) {
|
|
func(TENSOR_ARGS(mm_embed_norm, kMustRead));
|
|
}
|
|
}
|
|
|
|
for (size_t layer_idx = 0; layer_idx < c_layers.size(); ++layer_idx) {
|
|
if (other1) other_layer1 = other1->GetLayer(layer_idx);
|
|
if (other2) other_layer2 = other2->GetLayer(layer_idx);
|
|
GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
|
|
}
|
|
|
|
HWY_ASSERT(config_.vit_config.layer_configs.empty() == vit_layers.empty());
|
|
for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
|
|
HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
|
|
LayerAttentionType::kVit);
|
|
other_layer1 = other1 ? other1->VitLayer(layer_idx) : nullptr;
|
|
other_layer2 = other2 ? other2->VitLayer(layer_idx) : nullptr;
|
|
VitLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
|
|
}
|
|
} // `ForEachTensor`
|
|
|
|
// Zero-initializes only the allocated tensors in `*this`.
|
|
void ZeroInit();
|
|
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
|
void CopyFrom(const WeightsPtrs& other);
|
|
|
|
enum class Mode {
|
|
// Parallel I/O, decompress to BF16. Best for large batch sizes.
|
|
kReadBF16,
|
|
// Parallel I/O, insert row-wise padding. Safe default.
|
|
kRead,
|
|
// Best for large weights relative to available memory, especially for
|
|
// frequent invocations of small batches and short sequences. Adds noise to
|
|
// performance measurements due to I/O variability.
|
|
kMap
|
|
};
|
|
|
|
static const char* ToString(Mode mode) {
|
|
switch (mode) {
|
|
case Mode::kReadBF16:
|
|
return "ReadBF16";
|
|
case Mode::kRead:
|
|
return "Read";
|
|
case Mode::kMap:
|
|
return "Map";
|
|
default:
|
|
HWY_DASSERT(false);
|
|
return "?";
|
|
}
|
|
}
|
|
|
|
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
|
|
// override for whether to map blobs or read them. Returns the mode used.
|
|
Mode ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
|
const LoaderArgs& loader, const InferenceArgs& inference,
|
|
std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
|
|
|
|
// Adds one blob for each tensor's data and returns all serialized MatPtr.
|
|
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
|
|
|
|
private:
|
|
// For reshaping file tensors to the shape expected by the code. This would
|
|
// ideally already happen in the importer. Called by ReadFromBlobs.
|
|
void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
|
|
|
|
MapPtr mapped_;
|
|
}; // `WeightsPtrs`
|
|
#undef TENSOR_ARGS
|
|
|
|
} // namespace gcpp
|
|
|
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|