// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include "compression/blob_store.h"  // BlobWriter
#include "compression/shared.h"      // IsF32
#include "gemma/configs.h"           // ModelConfig
#include "gemma/model_store.h"       // ModelStore
#include "gemma/tensor_info.h"       // TensorInfoRegistry
#include "util/mat.h"                // MatPtr
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

// Argument passed to the `ForEachTensor` callback.
struct TensorArgs {
  // `other_mat1` and `other_mat2` can be nullptr, or tensor(s) of the same
  // name/type from another `LayerWeightsPtrs` for iterating over tensor pairs
  // (for copying) or triples (for `AdamUpdateMV`). Set by `TENSOR_ARGS`.
  // `flags` is one of the `Flags` below.
  TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
             int flags)
      : mat(mat),
        other_mat1(other_mat1),
        other_mat2(other_mat2),
        flags(flags) {
    // Does not make sense to combine both flags.
    HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate));
  }

  MatPtr& mat;
  const MatPtr* other_mat1;  // either/both can be nullptr.
  const MatPtr* other_mat2;

  // TODO: freestanding enum class instead? These are mutually exclusive.
  enum Flags {
    // Read the tensor from the file and abort if it is not found.
    kMustRead = 0,
    // Not an error if the tensor is not present in the file. For example,
    // the _w1/_w2 tensors are not always present.
    kMaybeRead = 1,
    // Do not attempt to read, just allocate the tensor. Used for `Reshape`.
    kOnlyAllocate = 2,
  };
  const int flags;
};

// Shorthand for creating the argument to the `ForEachTensor` callback. A macro
// seems less bad than member pointer syntax.
#define TENSOR_ARGS(mat, flag)                     \
  TensorArgs(mat, other1 ? &other1->mat : nullptr, \
             other2 ? &other2->mat : nullptr, TensorArgs::flag)
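
// Illustrative sketch, not part of the API: a `ForEachTensor` callback (see
// `LayerWeightsPtrs` below) receives one `TensorArgs` per tensor and can
// inspect `mat` and `flags`. With a hypothetical `layer` object, counting the
// tensors that must be present in the file could look like:
//
//   size_t num_required = 0;
//   layer.ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
//     if (t.flags == TensorArgs::kMustRead) ++num_required;
//   });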
// Per-layer weight metadata and pointers. The tensor data is owned by
// `WeightsOwner`. Note that this class could be type-erased: member functions
// do not actually use the `Weight` template argument. See `WeightsPtrs`.
// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth
// for all tensor shapes.
template <class Weight>
struct LayerWeightsPtrs {
  static inline std::string Concat(const char* base_name,
                                   const std::string& suffix) {
    return std::string(base_name) + suffix;
  }

  // Initializes tensor metadata without allocating.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : suffix_(LayerSuffix(layer_idx)),
        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
        qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
        qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
        qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
        attention_output_biases(Concat("attn_ob", suffix_), tensors),
        griffin(
            {.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors},
             .linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors},
             .linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors},
             .linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors},
             .linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors},
             .linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors},
             .conv_w = {Concat("gr_conv_w", suffix_), tensors},
             .conv_biases = {Concat("gr_conv_b", suffix_), tensors},
             .gate_w = {Concat("gr_gate_w", suffix_), tensors},
             .gate_biases = {Concat("gr_gate_b", suffix_), tensors},
             .a = {Concat("gr_a", suffix_), tensors}}),
        // MultiHeadDotProductAttention.
        vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors},
             .attn_out_b = {Concat("attn_out_b", suffix_), tensors},
             .qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors},
             .qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors},
             .linear_0_w = {Concat("linear_0_w", suffix_), tensors},
             .linear_0_b = {Concat("linear_0_b", suffix_), tensors},
             .linear_1_w = {Concat("linear_1_w", suffix_), tensors},
             .linear_1_b = {Concat("linear_1_b", suffix_), tensors},
             .layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors},
             .layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors},
             .layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors},
             .layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}),
        gating_einsum_w(Concat("gating_ein", suffix_), tensors),
        gating_einsum_w1(Concat("gating1_w", suffix_), tensors),
        gating_einsum_w2(Concat("gating2_w", suffix_), tensors),
        linear_w(Concat("linear_w", suffix_), tensors),
        pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors),
        pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors),
        post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors),
        post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
        ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
        ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
        att_weights(Concat("att_w", suffix_), tensors),
        key_norm_scale(Concat("key_norm", suffix_), tensors),
        query_norm_scale(Concat("query_norm", suffix_), tensors),
        layer_config(config) {}

  ~LayerWeightsPtrs() = default;

  const std::string suffix_;

  // If weights are f32, this is also f32; otherwise at least bf16. Useful for
  // ops that do not yet support smaller compressed types, or require at least
  // bf16. When weights are f32, we also want such tensors to be f32.
  // If weights are complex, this is also complex.
  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<Weight, std::complex<double>>(),
              std::complex<double>,
              hwy::If<hwy::IsSame<Weight, double>(), double,
                      hwy::If<IsF32<Weight>(), float, BF16>>>;
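
  // For reference, this follows directly from the alias above (illustrative;
  // `SfpStream` is assumed to be the SFP weight type from compression/shared.h):
  //
  //   static_assert(
  //       hwy::IsSame<LayerWeightsPtrs<float>::WeightF32OrBF16, float>());
  //   static_assert(
  //       hwy::IsSame<LayerWeightsPtrs<SfpStream>::WeightF32OrBF16, BF16>());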
  MatPtrT<Weight> attn_vec_einsum_w;
  // qkv_einsum_w holds 2 different matrices, which may be separated out.
  // On reading, which is used depends on what is in the file.
  // At inference, the one with a non-null ptr is used.
  MatPtrT<Weight> qkv_einsum_w;
  MatPtrT<Weight> qkv_einsum_w1;
  MatPtrT<Weight> qkv_einsum_w2;
  MatPtrT<float> attention_output_biases;

  struct {
    MatPtrT<Weight> linear_x_w;
    MatPtrT<float> linear_x_biases;
    MatPtrT<Weight> linear_y_w;
    MatPtrT<float> linear_y_biases;
    MatPtrT<Weight> linear_out_w;
    MatPtrT<float> linear_out_biases;
    MatPtrT<float> conv_w;
    MatPtrT<float> conv_biases;
    MatPtrT<Weight> gate_w;
    MatPtrT<float> gate_biases;
    MatPtrT<float> a;
  } griffin;

  struct {
    // MultiHeadDotProductAttention.
    MatPtrT<Weight> attn_out_w;
    MatPtrT<float> attn_out_b;
    MatPtrT<Weight> qkv_einsum_w;
    MatPtrT<float> qkv_einsum_b;
    // MlpBlock.
    MatPtrT<Weight> linear_0_w;
    MatPtrT<float> linear_0_b;
    MatPtrT<Weight> linear_1_w;
    MatPtrT<float> linear_1_b;
    // LayerNorm.
    MatPtrT<WeightF32OrBF16> layer_norm_0_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_0_scale;
    MatPtrT<WeightF32OrBF16> layer_norm_1_bias;
    MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
  } vit;

  // gating_einsum_w holds 2 different matrices, which may be separated out.
  // On reading, which is used depends on what is in the file.
  // At inference, the one with a non-null ptr is used.
  MatPtrT<Weight> gating_einsum_w;
  MatPtrT<Weight> gating_einsum_w1;
  MatPtrT<Weight> gating_einsum_w2;
  MatPtrT<Weight> linear_w;
  // We don't yet have an RMSNorm that accepts all Weight.
  MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
  MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
  MatPtrT<WeightF32OrBF16> post_ffw_norm_scale;

  MatPtrT<float> ffw_gating_biases;
  MatPtrT<float> ffw_output_biases;
  MatPtrT<Weight> att_weights;  // For Reshape(); kOnlyAllocate.

  MatPtrT<WeightF32OrBF16> key_norm_scale;
  MatPtrT<WeightF32OrBF16> query_norm_scale;

  const LayerConfig& layer_config;
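
  // Illustrative sketch, not the actual inference code (which lives in
  // gemma-inl.h): after loading, exactly one representation of the fused vs.
  // split gating matrices has a non-null pointer, so callers can branch on it:
  //
  //   if (layer.gating_einsum_w.HasPtr()) {
  //     // Use the fused matrix (both halves in one tensor).
  //   } else {
  //     // Use the split gating_einsum_w1 / gating_einsum_w2.
  //   }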
  // Calls `func(TensorArgs)` for each tensor which is in use for the current
  // `layer_config`. `other1` and `other2` are optional arguments so we can
  // also iterate over pairs or triples of tensors for `AdamUpdateMV`.
  // Public because also called by `WeightsPtrs`.
  template <class Func>
  void ForEachTensor(const LayerWeightsPtrs<Weight>* other1,
                     const LayerWeightsPtrs<Weight>* other2, Func func) {
    if (layer_config.type == LayerAttentionType::kVit) {
      // MHA.
      func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
      func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
      func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
      func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead));
      // MlpBlock.
      func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_w, kMustRead));
      func(TENSOR_ARGS(vit.linear_1_b, kMustRead));
      // LayerNorm.
      func(TENSOR_ARGS(vit.layer_norm_0_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_0_scale, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_bias, kMustRead));
      func(TENSOR_ARGS(vit.layer_norm_1_scale, kMustRead));
      return;
    }
    if (layer_config.type == LayerAttentionType::kGemma) {
      // Not read, will be filled by Reshape() from `attn_vec_einsum_w`.
      func(TENSOR_ARGS(att_weights, kOnlyAllocate));
      func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead));
      func(TENSOR_ARGS(qkv_einsum_w, kMustRead));
      func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
    } else {
      func(TENSOR_ARGS(griffin.linear_x_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_x_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
      func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
      func(TENSOR_ARGS(griffin.conv_w, kMustRead));
      func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
      func(TENSOR_ARGS(griffin.gate_w, kMustRead));
      func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
      func(TENSOR_ARGS(griffin.a, kMustRead));
    }
    {
      func(TENSOR_ARGS(gating_einsum_w, kMustRead));
      func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
      func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
      func(TENSOR_ARGS(linear_w, kMustRead));
      func(TENSOR_ARGS(pre_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(pre_ffw_norm_scale, kMustRead));
    }
    if (layer_config.post_norm == PostNormType::Scale) {
      func(TENSOR_ARGS(post_attention_norm_scale, kMustRead));
      func(TENSOR_ARGS(post_ffw_norm_scale, kMustRead));
    }
    if (layer_config.use_qk_norm) {
      func(TENSOR_ARGS(key_norm_scale, kMustRead));
      func(TENSOR_ARGS(query_norm_scale, kMustRead));
    }
    if (layer_config.ff_biases) {
      func(TENSOR_ARGS(ffw_gating_biases, kMustRead));
      func(TENSOR_ARGS(ffw_output_biases, kMustRead));
    }
    if (layer_config.softmax_attn_output_biases &&
        layer_config.type == LayerAttentionType::kGemma) {
      func(TENSOR_ARGS(attention_output_biases, kMustRead));
    }
  }  // `ForEachTensor`

  // Zero-initializes all allocated tensors in the layer.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  void RandInit(float stddev, std::mt19937& gen) {
    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::RandInit(t.mat, stddev, gen);
    });
  }

  // Allocates memory for all the tensors in the layer. Note that this is slow
  // (non-parallel) and only used for a stand-alone layer.
  void AllocateForTest(MatOwners& mat_owners) {
    ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
      // `backprop/` does not use row accessors and hence requires kPacked.
      mat_owners.AllocateFor(t.mat, MatPadding::kPacked);
    });
  }

  // Initializes `att_weights` from `attn_vec_einsum_w`, hence this must be
  // called after reading weights via `ForEachTensor`.
  // TODO: update compression/convert_weights to bake this in.
  void Reshape() {
    // NUQ is handled by a specialization in weights.cc.
    HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);

    const size_t model_dim = layer_config.model_dim;
    const size_t heads = layer_config.heads;
    const size_t qkv_dim = layer_config.qkv_dim;

    // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
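    // Worked example of the mapping below (illustrative): with heads = 2,
    // model_dim = 3, qkv_dim = 4, input row h * model_dim + m (of width
    // qkv_dim) becomes columns [h * qkv_dim, (h + 1) * qkv_dim) of output row
    // m. E.g. input row 4 (h = 1, m = 1) lands in output row 1, columns 4..7.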
    HWY_ASSERT(att_weights.HasPtr());
    HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
    HWY_ASSERT(att_weights.Rows() == model_dim);
    HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
    HWY_ASSERT(attn_vec_einsum_w.HasPtr());
    HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
    HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);

    const size_t T_bytes = att_weights.ElementBytes();
    for (size_t m = 0; m < model_dim; ++m) {
      uint8_t* HWY_RESTRICT out_row =
          reinterpret_cast<uint8_t*>(att_weights.Row(m));
      for (size_t h = 0; h < heads; ++h) {
        hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m),
                       out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
      }
    }
    att_weights.SetScale(attn_vec_einsum_w.Scale());
  }
};

// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with
// `LayerWeightsPtrs`, this class could be type-erased: member functions do not
// actually use the `Weight` template argument. The template does allow user
// code to dispatch only once. However, most tensors are large enough that
// dispatch at each usage would be feasible.
// TODO: move `gemma-inl.h` toward dispatch at each usage.
// TODO: rename to WeightsPtrs.
template <class Weight>
struct ModelWeightsPtrs {
  using WeightT = Weight;

  explicit ModelWeightsPtrs(const ModelConfig& config)
      : tensors_(config),
        // No suffix, these are per-model.
        embedder_input_embedding("c_embedding", tensors_),
        final_norm_scale("c_final_norm", tensors_),
        vit_encoder_norm_bias("enc_norm_bias", tensors_),
        vit_encoder_norm_scale("enc_norm_scale", tensors_),
        vit_img_embedding_bias("img_emb_bias", tensors_),
        vit_img_embedding_kernel("img_emb_kernel", tensors_),
        vit_img_pos_embedding("img_pos_emb", tensors_),
        vit_img_head_bias("img_head_bias", tensors_),
        vit_img_head_kernel("img_head_kernel", tensors_),
        mm_embed_norm("mm_embed_norm", tensors_),
        weights_config(config) {
    c_layers.reserve(config.layer_configs.size());
    for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.layer_configs[idx];
      c_layers.emplace_back(idx, layer_config, tensors_);
    }
    for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) {
      const LayerConfig& layer_config = config.vit_config.layer_configs[idx];
      vit_layers.emplace_back(idx, layer_config, tensors_);
    }
  }

  ~ModelWeightsPtrs() = default;

  // = F32 if weights are F32, else BF16.
  using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;

  // Passed to all `MatPtrT` initializers, hence must be initialized first.
  const TensorInfoRegistry tensors_;

  // TODO: switch to SFP?
  MatPtrT<Weight> embedder_input_embedding;
  MatPtrT<WeightF32OrBF16> final_norm_scale;

  // Vit parts.
  MatPtrT<float> vit_encoder_norm_bias;
  MatPtrT<float> vit_encoder_norm_scale;
  MatPtrT<float> vit_img_embedding_bias;
  MatPtrT<Weight> vit_img_embedding_kernel;
  MatPtrT<float> vit_img_pos_embedding;
  // The head maps from VitConfig::model_dim (Vit final layer) to
  // model_dim (LLM input).
  MatPtrT<float> vit_img_head_bias;
  MatPtrT<Weight> vit_img_head_kernel;

  MatPtrT<WeightF32OrBF16> mm_embed_norm;

  const ModelConfig& weights_config;
  std::vector<LayerWeightsPtrs<Weight>> c_layers;
  std::vector<LayerWeightsPtrs<Weight>> vit_layers;

  const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
    return &c_layers[layer];
  }
  LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
  const LayerWeightsPtrs<Weight>* VitLayer(size_t layer) const {
    return &vit_layers[layer];
  }
  LayerWeightsPtrs<Weight>* VitLayer(size_t layer) {
    return &vit_layers[layer];
  }
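
  // Illustrative usage sketch (`config` is a hypothetical `ModelConfig`): the
  // pointers only describe tensors; they own no data until `WeightsOwner` (or
  // `AllocateForTest`) allocates it.
  //
  //   ModelWeightsPtrs<float> weights(config);   // metadata only
  //   const LayerWeightsPtrs<float>* layer0 = weights.GetLayer(0);
  //   HWY_ASSERT(!layer0->linear_w.HasPtr());    // not yet allocated/read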
  // Called via `CallT`. `other1` and `other2` are usually null, but can be
  // used to copy from another set of weights. Public because called by tests
  // and `WeightsOwner`.
  template <class Func>
  void ForEachTensor(const ModelWeightsPtrs<Weight>* other1,
                     const ModelWeightsPtrs<Weight>* other2, Func func) {
    const LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
    const LayerWeightsPtrs<Weight>* other_layer2 = nullptr;

    func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
    func(TENSOR_ARGS(final_norm_scale, kMustRead));

    if (!weights_config.vit_config.layer_configs.empty()) {
      // Vit parts.
      func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
      func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_embedding_kernel, kMustRead));
      func(TENSOR_ARGS(vit_img_pos_embedding, kMustRead));
      func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
      func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));
      if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
        func(TENSOR_ARGS(mm_embed_norm, kMustRead));
      }
    }

    for (size_t layer_idx = 0; layer_idx < c_layers.size(); ++layer_idx) {
      if (other1) other_layer1 = other1->GetLayer(layer_idx);
      if (other2) other_layer2 = other2->GetLayer(layer_idx);
      GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }

    HWY_ASSERT(weights_config.vit_config.layer_configs.empty() ==
               vit_layers.empty());
    for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
      HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
                 LayerAttentionType::kVit);
      other_layer1 = other1 ? other1->VitLayer(layer_idx) : nullptr;
      other_layer2 = other2 ? other2->VitLayer(layer_idx) : nullptr;
      VitLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
    }
  }  // `ForEachTensor`

  // Zero-initializes only the allocated tensors in `*this`.
  void ZeroInit() {
    ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::ZeroInit(t.mat);
    });
  }

  void RandInit(float stddev, std::mt19937& gen) {
    ForEachTensor(nullptr, nullptr, [stddev, &gen](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      gcpp::RandInit(t.mat, stddev, gen);
    });
  }

  // Copies only the allocated tensors in `*this` from tensors in `other`.
  void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
    ForEachTensor(&other, nullptr, [](const TensorArgs& t) {
      if (!t.mat.HasPtr()) return;
      HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
      CopyMat(*t.other_mat1, t.mat);
    });
  }

  // Instead of reading, only allocates memory for all tensors. Used by
  // `optimizer.cc` via the `Gemma` constructor without weights.
  void AllocateForTest(MatOwners& mat_owners, hwy::ThreadPool& pool) {
    // First get a list of all the tensors.
    std::vector<MatPtr*> all_mat;
    all_mat.reserve(10 * c_layers.size());
    ForEachTensor(nullptr, nullptr, [&all_mat](const TensorArgs& t) {
      all_mat.push_back(&t.mat);
    });
    // `backprop/` does not use row accessors and hence requires kPacked.
    mat_owners.AllocateFor(all_mat, MatPadding::kPacked, pool);
  }

  // For reshaping file tensors to the shape expected by the code. This would
  // ideally already happen in the importer. Must be called after reading and
  // updating the attention weights.
  void Reshape(hwy::ThreadPool& pool) {
    pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
      GetLayer(layer)->Reshape();
    });
    pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
      VitLayer(layer)->Reshape();
    });
  }
};  // `ModelWeightsPtrs`

#undef TENSOR_ARGS
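
// Illustrative sketch of pair-wise iteration (this is what `CopyFrom` above
// does): passing a second `ModelWeightsPtrs` as `other1` makes `TENSOR_ARGS`
// fill `TensorArgs::other_mat1` with the tensor of the same name, so the
// callback can copy from it. `dst` and `src` are hypothetical instances:
//
//   dst.ForEachTensor(&src, nullptr, [](const TensorArgs& t) {
//     if (t.mat.HasPtr() && t.other_mat1->HasPtr()) {
//       CopyMat(*t.other_mat1, t.mat);
//     }
//   });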
// Type-erased facade for `ModelWeightsPtrs`, stored inside the non-template
// `Gemma`. Also owns the underlying memory.
class WeightsOwner {
 public:
  // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
  WeightsOwner(Type weight_type) : weight_type_(weight_type) {}

  // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
  // allocates memory and reshapes. Aborts on error.
  void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
                      hwy::ThreadPool& pool);

  // Calls `func(std::unique_ptr<ModelWeightsPtrs<T>>&, args)`. `func`
  // typically calls `ForEachTensor`.
  template <class Func, typename... TArgs>
  decltype(auto) CallT(const Func& func, TArgs&&... args) const {
    if (HWY_LIKELY(weight_type_ == Type::kSFP)) {
      return func(sfp_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kNUQ) {
      return func(nuq_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kF32) {
      return func(float_weights_, std::forward<TArgs>(args)...);
    } else if (weight_type_ == Type::kBF16) {
      return func(bf16_weights_, std::forward<TArgs>(args)...);
    }
    return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
  }

  // For writers:
  // Adds one blob for each tensor's data and returns all serialized MatPtr.
  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

  // For backprop/:
  // Only allocates; must subsequently call `ZeroInit` or `RandInit`.
  void AllocateForTest(const ModelConfig& config, hwy::ThreadPool& pool);
  void ZeroInit();
  void RandInit(float stddev, std::mt19937& gen);  // F32 or F64 only.
  void LogWeightStatsF32();

  ModelWeightsPtrs<float>* GetF32() const {
    HWY_ASSERT(weight_type_ == Type::kF32);
    return float_weights_.get();
  }

  // Usually taken care of by `ReadOrAllocate`, but must also be called by
  // `optimize_test`, which updates the attention weights from which this
  // copies.
  void Reshape(hwy::ThreadPool& pool);

 private:
  Type weight_type_;

  // Allocates `*_weights_`, but not yet the tensors inside. This is split out
  // of `CallT` so that `CallT` can be const.
  void AllocatePointer(const ModelConfig& config);

  // Only one is non-null, determined by `weight_type_`.
  std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
  std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
  std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
  std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;

  // Owns the memory referenced by all `MatPtr`.
  MatOwners mat_owners_;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_