// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>
#include <stdint.h>

#include <string>
#include <vector>

#include "compression/types.h"
#include "gemma/configs.h"      // ModelConfig
#include "gemma/gemma_args.h"   // InferenceArgs
#include "gemma/model_store.h"  // ModelStore
#include "gemma/tensor_info.h"  // TensorInfoRegistry
#include "io/blob_store.h"      // BlobWriter
#include "util/mat.h"           // MatPtr
#include "util/threading_context.h"

namespace gcpp {

// Argument passed to the `ForEachTensor` callback.
struct TensorArgs {
  // `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same
  // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
  // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
  // `flags` is a combination of zero or more `Flags`.
  TensorArgs(MatPtr& mat, MatPtr* other_mat1, MatPtr* other_mat2, int flags)
      : mat(mat),
        other_mat1(other_mat1),
        other_mat2(other_mat2),
        flags(flags) {}

  MatPtr& mat;
  MatPtr* other_mat1;  // either/both can be nullptr.
  MatPtr* other_mat2;

  enum Flags {
    // Default: Read the tensor from the file and abort if it is not found.
    kMustRead = 0,
    // Not an error if the tensor is not present in the file. For example,
    // the _w1/_w2 tensors are not always present.
    kMaybeRead = 1,
    // Avoid padding tensor rows when reading.
    kPacked = 2,
  };
  const int flags;
};

// Shorthand for creating the argument to the `ForEachTensor` callback. A macro
// seems less bad than member pointer syntax.
#define TENSOR_ARGS(mat, flag)                      \
  TensorArgs(mat, other1 ? &other1->mat : nullptr,  \
             other2 ? &other2->mat : nullptr, TensorArgs::flag)
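// For illustration only: inside `LayerWeightsPtrs::ForEachTensor` (below),
//   func(TENSOR_ARGS(linear_w, kMaybeRead));
// expands to
//   func(TensorArgs(linear_w, other1 ? &other1->linear_w : nullptr,
//                   other2 ? &other2->linear_w : nullptr,
//                   TensorArgs::kMaybeRead));
// i.e. the macro requires `other1` and `other2` to be in scope at the call
// site, which is why it is only used inside the `ForEachTensor` methods.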
// Finds tensors by name in `TensorInfoRegistry` (constructed from
// `ModelConfig`) and constructs `MatPtr` metadata with those shapes.
class MatFinder {
 public:
  MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors)
      : suffix_(suffix), tensors_(tensors) {}

  // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`.
  MatPtr operator()(const std::string& base_name) const {
    const std::string name = std::string(base_name) + suffix_;
    return MatPtr(name.c_str(), Type::kUnknown,
                  ExtentsFromInfo(tensors_.Find(name)));
  }

 private:
  const std::string suffix_;
  const TensorInfoRegistry& tensors_;
};
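// Illustrative sketch (not part of this header) of how the pieces fit
// together, assuming a `ModelConfig config` obtained elsewhere:
//
//   TensorInfoRegistry tensors(config);
//   MatFinder finder("", tensors);             // empty suffix: per-model.
//   MatPtr embedding = finder("c_embedding");  // shape from the registry;
//                                              // Type::kUnknown, no data yet.
//
// Per-layer tensors use `MatFinder(LayerSuffix(layer_idx), tensors)`, as in
// `LayerWeightsPtrs` below.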
// Per-layer weight metadata and pointers. The tensor data is owned by
// `MatOwner`.
struct LayerWeightsPtrs {
  // Initializes tensor metadata without allocating.
  // NOTE: do not store layer_idx, TransformerLayer and Attention may use
  // other values for purposes of the KV cache.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : finder_(LayerSuffix(layer_idx), tensors),
        qkv_einsum_w(finder_("qkv_ein")),
        qkv_einsum_w1(finder_("qkv1_w")),
        qkv_einsum_w2(finder_("qkv2_w")),
        attention_output_biases(finder_("attn_ob")),
        // MultiHeadDotProductAttention.
        vit({.attn_out_w = finder_("attn_out_w"),
             .attn_out_b = finder_("attn_out_b"),
             .qkv_einsum_w = finder_("qkv_ein_w"),
             .qkv_einsum_b = finder_("qkv_ein_b"),
             .linear_0_w = finder_("linear_0_w"),
             .linear_0_b = finder_("linear_0_b"),
             .linear_1_w = finder_("linear_1_w"),
             .linear_1_b = finder_("linear_1_b"),
             .layer_norm_0_bias = finder_("ln_0_bias"),
             .layer_norm_0_scale = finder_("ln_0_scale"),
             .layer_norm_1_bias = finder_("ln_1_bias"),
             .layer_norm_1_scale = finder_("ln_1_scale")}),
        gating_einsum_w(finder_("gating_ein")),
        gating_einsum_w1(finder_("gating1_w")),
        gating_einsum_w2(finder_("gating2_w")),
        linear_w(finder_("linear_w")),
        pre_attention_norm_scale(finder_("pre_att_ns")),
        pre_ffw_norm_scale(finder_("pre_ff_ns")),
        post_attention_norm_scale(finder_("post_att_ns")),
        post_ffw_norm_scale(finder_("post_ff_ns")),
        ffw_gating_biases(finder_("ffw_gat_b")),
        ffw_output_biases(finder_("ffw_out_b")),
        attn_vec_einsum_w(finder_("att_ein")),
        att_weights(finder_("att_w")),
        key_norm_scale(finder_("key_norm")),
        query_norm_scale(finder_("query_norm")),
        layer_config(config) {}

  ~LayerWeightsPtrs() = default;

  const MatFinder finder_;

  // Files either have qkv_einsum_w with 2 stacked matrices or separate
  // w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h.
  MatPtr qkv_einsum_w;
  MatPtr qkv_einsum_w1;
  MatPtr qkv_einsum_w2;

  MatPtrT<float> attention_output_biases;

  struct {
    // MultiHeadDotProductAttention.
    MatPtr attn_out_w;  // at least BF16.
    MatPtrT<float> attn_out_b;
    MatPtr qkv_einsum_w;  // at least BF16.
    MatPtrT<float> qkv_einsum_b;
    // MlpBlock.
    MatPtr linear_0_w;  // at least BF16.
    MatPtrT<float> linear_0_b;
    MatPtr linear_1_w;  // at least BF16.
    MatPtrT<float> linear_1_b;
    // LayerNorm.
    MatPtr layer_norm_0_bias;   // at least BF16.
    MatPtr layer_norm_0_scale;  // at least BF16.
    MatPtr layer_norm_1_bias;   // at least BF16.
    MatPtr layer_norm_1_scale;  // at least BF16.
  } vit;

  // Files either have gating_einsum_w with 2 stacked matrices or separate
  // w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h.
  MatPtr gating_einsum_w;
  MatPtr gating_einsum_w1;
  MatPtr gating_einsum_w2;
  MatPtr linear_w;
  MatPtr pre_attention_norm_scale;   // at least BF16.
  MatPtr pre_ffw_norm_scale;         // at least BF16.
  MatPtr post_attention_norm_scale;  // at least BF16.
  MatPtr post_ffw_norm_scale;        // at least BF16.

  MatPtrT<float> ffw_gating_biases;
  MatPtrT<float> ffw_output_biases;

  MatPtr attn_vec_einsum_w;  // Use att_weights instead of this.
  MatPtr att_weights;        // Use this instead of attn_vec_einsum_w.

  MatPtr key_norm_scale;    // at least BF16.
  MatPtr query_norm_scale;  // at least BF16.

  const LayerConfig& layer_config;

  // Calls `func(TensorArgs)` for each tensor that is in use for the current
  // `layer_config`. `other1` and `other2` are optional arguments so we can
  // also iterate over pairs or triples of tensors for `AdamUpdateMV`.
  // Public because also called by `WeightsPtrs`.
  template <class Func>
  void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2,
                     Func func) {
    if (layer_config.type == LayerAttentionType::kVit) {
      // MHA.
      func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
      func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
      func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
      // Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
      // must not be padded.
      func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
      // MlpBlock.
      func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_b, kMustRead));
      // LayerNorm.
      func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead));
      return;
    }
    if (layer_config.type == LayerAttentionType::kGemma) {
      // Either read from file, or allocated during Fixup().
      func(TENSOR_ARGS(att_weights, kMaybeRead));
      func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
    }
    {
      func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
      func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
      func(TENSOR_ARGS(linear_w, kMaybeRead));
      func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead));
    }
    if (layer_config.post_norm == PostNormType::Scale) {
      func(TENSOR_ARGS(post_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead));
    }
    if (layer_config.use_qk_norm) {
      func(TENSOR_ARGS(key_norm_scale, kMustRead));
      func(TENSOR_ARGS(query_norm_scale, kMustRead));
    }
    if (layer_config.ff_biases) {
      func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
      func(TENSOR_ARGS(ffw_output_biases, kMustRead));
    }
  }  // `ForEachTensor`

  // Zero-initializes all allocated tensors in the layer.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  // Must be called after reading weights via `ForEachTensor`.
  // TODO: exporters should bake this into the weights already.
  // WARNING: called from multiple threads; `mat_owners` requires a lock.
  void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);

 private:
  // Copies att_weights from `attn_vec_einsum_w`.
  void InitAttWeights(std::vector<MatOwner>& mat_owners,
                      const Allocator& allocator);

  // For FFN. Fast, only updates pointers.
  void SplitW1();
  // For attention, which might not have a w2. Fast, only updates pointers.
  void SplitAttW1();
};
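// Illustrative sketch (not part of this header): pairwise iteration over two
// layers, e.g. as the basis for copying weights; `CopyMat` is a hypothetical
// helper, and the real copy path is `WeightsPtrs::CopyFrom` below.
//
//   dst.ForEachTensor(&src, nullptr, [](const TensorArgs& t) {
//     if (t.other_mat1 && t.other_mat1->HasPtr() && t.mat.HasPtr()) {
//       CopyMat(*t.other_mat1, t.mat);  // hypothetical: copy src into dst.
//     }
//   });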
// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `MatOwner`.
struct WeightsPtrs {
  explicit WeightsPtrs(const ModelConfig& config)
      : config_(config),
        tensors_(config_),
        finder_("", tensors_),  // no suffix because these are per-model.
        embedder_input_embedding(finder_("c_embedding")),
        final_norm_scale(finder_("c_final_norm")),
        vit_encoder_norm_bias(finder_("enc_norm_bias")),
        vit_encoder_norm_scale(finder_("enc_norm_scale")),
        vit_img_embedding_bias(finder_("img_emb_bias")),
        vit_img_embedding_kernel(finder_("img_emb_kernel")),
        vit_img_pos_embedding(finder_("img_pos_emb")),
        vit_img_head_bias(finder_("img_head_bias")),
        vit_img_head_kernel(finder_("img_head_kernel")),
        mm_embed_norm(finder_("mm_embed_norm")),
        c_layers() {
    c_layers.reserve(config_.layer_configs.size());
    for (size_t idx = 0; idx < config_.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config_.layer_configs[idx];
      c_layers.emplace_back(idx, layer_config, tensors_);
    }
    for (size_t idx = 0; idx < config_.vit_config.layer_configs.size();
         ++idx) {
      const LayerConfig& layer_config = config_.vit_config.layer_configs[idx];
      vit_layers.emplace_back(idx, layer_config, tensors_);
    }
  }
  ~WeightsPtrs() = default;

  const ModelConfig& config_;
  // Passed to finder_, hence must be initialized first.
  const TensorInfoRegistry tensors_;
  const MatFinder finder_;

  // TODO: switch to SFP?
  MatPtr embedder_input_embedding;
  MatPtr final_norm_scale;  // at least BF16.

  // Vit parts.
  MatPtr vit_encoder_norm_bias;   // at least BF16.
  MatPtr vit_encoder_norm_scale;  // at least BF16.
  MatPtrT<float> vit_img_embedding_bias;
  MatPtr vit_img_embedding_kernel;  // at least BF16.
  MatPtr vit_img_pos_embedding;     // F32?
  // The head maps from VitConfig::model_dim (Vit final layer) to
  // model_dim (LLM input).
  MatPtrT<float> vit_img_head_bias;
  MatPtr vit_img_head_kernel;  // at least BF16.

  MatPtr mm_embed_norm;  // at least BF16.

  std::vector<LayerWeightsPtrs> c_layers;
  std::vector<LayerWeightsPtrs> vit_layers;

  const LayerWeightsPtrs* GetLayer(size_t layer) const {
    return &c_layers[layer];
  }
  LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; }
  const LayerWeightsPtrs* VitLayer(size_t layer) const {
    return &vit_layers[layer];
  }
  LayerWeightsPtrs* VitLayer(size_t layer) { return &vit_layers[layer]; }

  // Called via `CallT`. `other1` and `other2` are usually null, but can be
  // used to copy from another set of weights. Public because called by tests
  // and `WeightsOwner`.
  template <class Func>
  void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
    LayerWeightsPtrs* other_layer1 = nullptr;
    LayerWeightsPtrs* other_layer2 = nullptr;
    func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
    func(TENSOR_ARGS(final_norm_scale, kMustRead));

    if (!config_.vit_config.layer_configs.empty()) {
      // Vit parts.
      func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
      func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_kernel, kMustRead));
      func(TENSOR_ARGS(vit_img_pos_embedding, kMustRead));
      func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));
      if (config_.wrapping == PromptWrapping::GEMMA_VLM) {
        func(TENSOR_ARGS(mm_embed_norm, kMustRead));
      }
    }

    for (size_t layer_idx = 0; layer_idx < c_layers.size(); ++layer_idx) {
      if (other1) other_layer1 = other1->GetLayer(layer_idx);
      if (other2) other_layer2 = other2->GetLayer(layer_idx);
      GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }

    HWY_ASSERT(config_.vit_config.layer_configs.empty() == vit_layers.empty());
    for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
      HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
                 LayerAttentionType::kVit);
      other_layer1 = other1 ? other1->VitLayer(layer_idx) : nullptr;
      other_layer2 = other2 ? other2->VitLayer(layer_idx) : nullptr;
      VitLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }
  }  // `ForEachTensor`

  // Zero-initializes only the allocated tensors in `*this`.
  void ZeroInit();
  // Copies only the allocated tensors in `*this` from tensors in `other`.
  void CopyFrom(const WeightsPtrs& other);

  enum class Mode {
    // Parallel I/O, decompress to BF16. Best for large batch sizes.
    kReadBF16,
    // Parallel I/O, insert row-wise padding. Safe default.
    kRead,
    // Best for large weights relative to available memory, especially for
    // frequent invocations of small batches and short sequences. Adds noise
    // to performance measurements due to I/O variability.
    kMap
  };

  static const char* ToString(Mode mode) {
    switch (mode) {
      case Mode::kReadBF16:
        return "ReadBF16";
      case Mode::kRead:
        return "Read";
      case Mode::kMap:
        return "Map";
      default:
        HWY_DASSERT(false);
        return "?";
    }
  }
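  // Illustrative usage of `ReadFromBlobs` (declared below); `config`, `model`,
  // `reader`, `loader`, `inference`, `mat_owners` and `ctx` are assumed to
  // have been set up by the caller:
  //
  //   WeightsPtrs weights(config);
  //   const WeightsPtrs::Mode mode = weights.ReadFromBlobs(
  //       model, reader, loader, inference, mat_owners, ctx);
  //   fprintf(stderr, "Loaded weights via %s\n", WeightsPtrs::ToString(mode));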
  // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
  // override for whether to map blobs or read them. Returns the mode used.
  Mode ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                     const LoaderArgs& loader, const InferenceArgs& inference,
                     std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);

  // Adds one blob for each tensor's data and returns all serialized MatPtr.
  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

 private:
  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Called by `ReadFromBlobs`.
  void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);

  MapPtr mapped_;
};  // `WeightsPtrs`

#undef TENSOR_ARGS

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_