// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <array>

#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

namespace gcpp {

template <class TConfig>
struct CompressedLayer {
  // No ctor/dtor, allocated via AllocateAligned.

  using Weight = typename TConfig::Weight;

  // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
  // do not yet support smaller compressed types, or require at least bf16.
  // When weights are f32, we also want such tensors to be f32.
  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;

  static constexpr size_t kHeads = TConfig::kHeads;
  static constexpr size_t kKVHeads = TConfig::kKVHeads;
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr size_t kQKVDim = TConfig::kQKVDim;
  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
  static constexpr size_t kQKVEinsumWSize =
      (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
  // 2x for (gelu gating vector, gated vector)
  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
  static constexpr bool kFFBiases = TConfig::kFFBiases;
  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
  static constexpr size_t kAOBiasDim =
      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
  static constexpr size_t kGriffinDim =
      TConfig::kGriffinLayers > 0 ? kModelDim : 0;

  template <class T, size_t N>
  using ArrayT = CompressedArray<T, N>;

  union {
    struct {
      ArrayT<Weight, kAttVecEinsumWSize> attn_vec_einsum_w;
      ArrayT<Weight, kQKVEinsumWSize> qkv_einsum_w;
      ArrayT<float, kAOBiasDim> attention_output_biases;
    };

    struct {
      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_x_w;
      ArrayT<float, kGriffinDim> linear_x_biases;
      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_y_w;
      ArrayT<float, kGriffinDim> linear_y_biases;
      ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
      ArrayT<float, kGriffinDim> linear_out_biases;
      ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
      ArrayT<float, kGriffinDim> conv_biases;
      ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
      ArrayT<float, kGriffinDim * 2> gate_biases;
      ArrayT<float, kGriffinDim> a;
    } griffin;
  };

  ArrayT<Weight, kGatingEinsumWSize> gating_einsum_w;
  ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
  // We don't yet have an RMSNorm that accepts all Weight.
  ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
  ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
      post_attention_norm_scale;
  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
      post_ffw_norm_scale;

  ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
  ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;

  // Reshaped attention; not loaded from disk via ForEachTensor.
  ArrayT<Weight, kAttVecEinsumWSize> att_weights;

  // Initializes att_weights from attn_vec_einsum_w, hence this must be called
  // after loading weights via ForEachTensor.
  // TODO: update compression/convert_weights to bake this in.
  void Reshape() {
    PROFILER_ZONE("Startup.Reshape");

    constexpr size_t kModelDim = TConfig::kModelDim;
    constexpr size_t kHeads = TConfig::kHeads;
    constexpr size_t kQKVDim = TConfig::kQKVDim;

    // Would have to implement a CompressTraits::Copy for NUQ.
    static_assert(!hwy::IsSame<Weight, NuqStream>());

    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
    for (size_t m = 0; m < kModelDim; ++m) {
      Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim;
      for (size_t h = 0; h < kHeads; ++h) {
        hwy::CopyBytes(
            attn_vec_einsum_w.data() + h * kModelDim * kQKVDim + m * kQKVDim,
            out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
      }
    }
  }
};
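// Worked example of the index mapping Reshape() performs (illustration only;
// the small dimensions below are hypothetical). attn_vec_einsum_w stores
// element (h, m, q) at offset h * kModelDim * kQKVDim + m * kQKVDim + q,
// whereas att_weights stores the same element at
// m * kHeads * kQKVDim + h * kQKVDim + q. With kHeads = 2, kModelDim = 3 and
// kQKVDim = 4, element (h=1, m=0, q=2) moves from offset 1*12 + 0*4 + 2 = 14
// to offset 0*8 + 1*4 + 2 = 6: each row of att_weights gathers the per-head
// kQKVDim slices belonging to one model dimension.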
// Array instead of single large allocation for parallel mem init. Split out
// of CompressedWeights so that only these pointers are initialized, not the
// CompressedArray.
template <class TConfig>
struct CompressedLayerPointers {
  explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
      this->c_layers[task] = hwy::AllocateAligned<CLayer>(1);
    });
  }

  using CLayer = CompressedLayer<TConfig>;
  std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
};

template <class TConfig>
struct CompressedWeights {
  // Must be allocated via AllocateAligned and initialized with placement new.
  void* operator new(size_t, void* addr) { return addr; }
  void* operator new(size_t) = delete;
  void* operator new[](size_t) = delete;
  void operator delete(void*) = delete;
  void operator delete[](void*) = delete;

  using Weight = typename TConfig::Weight;

  using WeightF32OrInputT =
      hwy::If<hwy::IsSame<Weight, float>(), float, EmbedderInputT>;
  CompressedArray<WeightF32OrInputT, TConfig::kVocabSize * TConfig::kModelDim>
      embedder_input_embedding;

  using WeightF32OrBF16 =
      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
  CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;

  // Must be last so that the other arrays remain aligned.
  CompressedLayerPointers<TConfig> c_layer_ptrs;

  explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}

  // Called by weights.cc after ForEachTensor.
  void Reshape() {
    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
      GetLayer(layer)->Reshape();
    }
  }

  void ZeroInit() {
    hwy::ZeroBytes(&embedder_input_embedding,
                   sizeof(embedder_input_embedding));
    hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
    for (int i = 0; i < TConfig::kLayers; ++i) {
      hwy::ZeroBytes(GetLayer(i), sizeof(*GetLayer(i)));
    }
  }

  const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
    return c_layer_ptrs.c_layers[layer].get();
  }
  CompressedLayer<TConfig>* GetLayer(size_t layer) {
    return c_layer_ptrs.c_layers[layer].get();
  }
};

// ----------------------------------------------------------------------------
// Interface

template <class TConfig>
struct AllocateCompressedWeights {
  ByteStorageT operator()(hwy::ThreadPool& pool) const {
    using TWeights = CompressedWeights<TConfig>;
    ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
    TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
    new (weights) TWeights(pool);
    return weights_u8;
  }
};

template <class TConfig>
struct ZeroInitCompressedWeights {
  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
    CompressedWeights<TConfig>& weights =
        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
    weights.ZeroInit();
  }
};

template <class TConfig>
struct ReshapeCompressedWeights {
  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
    CompressedWeights<TConfig>& weights =
        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
    weights.Reshape();
  }
};

// TODO: also add RandInitCompressedWeights

template <class TConfig>
struct DeleteCompressedWeights {
  void operator()(ByteStorageT& weights_u8) const {
    CompressedWeights<TConfig>& weights =
        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
    weights.~CompressedWeights();
  }
};

ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
                                   Type weight_type, hwy::ThreadPool& pool);

void LogWeightStats(Model model, Type weight_type,
                    const ByteStorageT& weights);
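// Example of the intended calling sequence for the functors above (a sketch
// only; `MyConfig` stands for any config type from configs.h and is not
// defined here). Because the storage is raw bytes, construction, optional
// zero-initialization and destruction all go through these functors:
//
//   hwy::ThreadPool pool(0);  // 0 worker threads: run tasks on the caller
//   ByteStorageT weights_u8 = AllocateCompressedWeights<MyConfig>()(pool);
//   ZeroInitCompressedWeights<MyConfig>()(weights_u8, pool);
//   // ... load or fill tensors, then reshape the attention weights ...
//   ReshapeCompressedWeights<MyConfig>()(weights_u8, pool);
//   DeleteCompressedWeights<MyConfig>()(weights_u8);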
// ----------------------------------------------------------------------------
// Iterators

// We rely on `if constexpr` to ensure raw_weights->member is only compiled
// when valid, i.e., kHaveRaw == true, but the IDE analysis does not understand
// this, hence hide the member access from it.
#if HWY_IDE
#define GEMMA_MEMBER(aggregate, member) nullptr
#else
#define GEMMA_MEMBER(aggregate, member) aggregate->member
#endif

// Used by ForEachTensor for tensors that are not in a layer.
#define GEMMA_CALL_TOP_FUNC(name, member)                    \
  {                                                          \
    const float* raw_tensor = nullptr;                       \
    if constexpr (kHaveRaw) {                                \
      raw_tensor = GEMMA_MEMBER(raw_weights, member.data()); \
    }                                                        \
    func(name, raw_tensor, c_weights.member);                \
  }

// Used by ForEachTensor for per-layer tensors. Writes into name_buf.
#define GEMMA_CALL_FUNC(name, member)                          \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  {                                                            \
    const float* raw_tensor = nullptr;                         \
    if constexpr (kHaveRaw) {                                  \
      raw_tensor = GEMMA_MEMBER(raw_layer, member.data());     \
    }                                                          \
    func(name_buf, raw_tensor, c_layer->member);               \
  }

// Calls func(name, float*, CompressedArray&) for each tensor. The float* is
// null if raw_weights is nullptr, e.g. when loading weights from BlobStore.
// Otherwise, RawLayer must be specified and we pass a float* pointing to the
// raw float weights for that tensor, for use by compress_weights.cc.
//
// This avoids repeating the list of tensors between loading and compressing,
// while also avoiding a dependency on raw_weights.h.
//
// Func is only called for the tensors that TConfig requests/specifies, which
// means scale() is uninitialized for the other tensors, so their data_scale1()
// must not be called. (In other words, if the config does not specify a
// tensor, it should not be used.)
template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
void ForEachTensor(RawWeightsPtr raw_weights,
                   CompressedWeights<TConfig>& c_weights, Func& func) {
  constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, std::nullptr_t>();

  GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);

  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const RawLayer* raw_layer = nullptr;
    if constexpr (kHaveRaw) {
      raw_layer = raw_weights->GetLayer(idx);
    }
    CompressedLayer<TConfig>* c_layer = c_weights.GetLayer(idx);

    GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
    GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
    GEMMA_CALL_FUNC("linear_w", linear_w);
    if (type == LayerAttentionType::kGemma) {
      GEMMA_CALL_FUNC("qkv_ein", qkv_einsum_w);
      GEMMA_CALL_FUNC("att_ein", attn_vec_einsum_w);
    } else {
      GEMMA_CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
      GEMMA_CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
      GEMMA_CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
      GEMMA_CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
      GEMMA_CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
      GEMMA_CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
      GEMMA_CALL_FUNC("gr_conv_w", griffin.conv_w);
      GEMMA_CALL_FUNC("gr_conv_b", griffin.conv_biases);
      GEMMA_CALL_FUNC("gr_gate_w", griffin.gate_w);
      GEMMA_CALL_FUNC("gr_gate_b", griffin.gate_biases);
      GEMMA_CALL_FUNC("gr_a", griffin.a);
    }
    GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);

    if (TConfig::kPostNorm == PostNormType::Scale) {
      GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
      GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
    }

    if (TConfig::kFFBiases) {
      GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
      GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
    }

    if (TConfig::kSoftmaxAttnOutputBiases &&
        type == LayerAttentionType::kGemma) {
      GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
    }
  }
#undef GEMMA_CALL_FUNC
#undef GEMMA_CALL_TOP_FUNC
}  // ForEachTensor
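// Example functor for ForEachTensor (a sketch only; `MyConfig` and `c_weights`
// are hypothetical). Passing nullptr for raw_weights means the float* argument
// is always null, as when loading from BlobStore:
//
//   size_t num_tensors = 0;
//   auto counter = [&num_tensors](const char* name, const float* /*raw*/,
//                                 auto& /*compressed*/) { ++num_tensors; };
//   ForEachTensor<MyConfig>(nullptr, c_weights, counter);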
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define GEMMA_CALL_TOP_FUNC2(name, member) \
  func(name, weights1.member, weights2.member)
#define GEMMA_CALL_TOP_FUNC3(name, member) \
  func(name, weights1.member, weights2.member, weights3.member)
#define GEMMA_CALL_TOP_FUNC4(name, member)     \
  func(name, weights1.member, weights2.member, \
       weights3.member, weights4.member)

#define GEMMA_CALL_LAYER_FUNC1(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member)

#define GEMMA_CALL_LAYER_FUNC2(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member)

#define GEMMA_CALL_LAYER_FUNC3(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member, layer3.member)

#define GEMMA_CALL_LAYER_FUNC4(name, member)                   \
  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)

#define GEMMA_CALL_ALL_LAYER_FUNC(N)                                        \
  if (type == LayerAttentionType::kGemma) {                                 \
    GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w);               \
    GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w);                    \
  } else {                                                                  \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w);           \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases);      \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w);           \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases);      \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w);       \
    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases);  \
    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w);                \
    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases);           \
    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w);                \
    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases);           \
    GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a);                          \
  }                                                                         \
  GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w);                \
  GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w);                         \
  GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale);       \
  if (TConfig::kPostNorm == PostNormType::Scale) {                          \
    GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale);   \
    GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale);          \
  }                                                                         \
  GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale);              \
  if (TConfig::kFFBiases) {                                                 \
    GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases);             \
    GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases);             \
  }                                                                         \
  if (TConfig::kSoftmaxAttnOutputBiases &&                                  \
      type == LayerAttentionType::kGemma) {                                 \
    GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases);         \
  }

template <class TConfig, class Func>
void ForEachTensor1(Func& func, const CompressedWeights<TConfig>& weights1) {
  GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(1)
  }
}

template <class TConfig, class Func>
void ForEachTensor1(Func& func, CompressedWeights<TConfig>& weights1) {
  GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(1)
  }
}
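// Example for the per-weights iterators above (a sketch only; `weights1` is a
// hypothetical CompressedWeights<MyConfig> instance). The 2- and 4-argument
// variants below follow the same pattern, passing the functor one tensor
// reference per weight set:
//
//   auto print_name = [](const char* name, auto& /*tensor*/) {
//     fprintf(stderr, "%s\n", name);
//   };
//   ForEachTensor1(print_name, weights1);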
template <class TConfig, class Func>
void ForEachTensor2(Func& func, const CompressedWeights<TConfig>& weights1,
                    CompressedWeights<TConfig>& weights2) {
  GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
    CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(2)
  }
}

template <class TConfig, class Func>
void ForEachTensor4(Func& func, const CompressedWeights<TConfig>& weights1,
                    CompressedWeights<TConfig>& weights2,
                    CompressedWeights<TConfig>& weights3,
                    CompressedWeights<TConfig>& weights4) {
  GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
  GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
  char name_buf[16];
  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
    auto type = TConfig::kLayerConfig[layer_idx];
    const size_t idx = static_cast<size_t>(layer_idx);
    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
    CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
    CompressedLayer<TConfig>& layer3 = *weights3.GetLayer(idx);
    CompressedLayer<TConfig>& layer4 = *weights4.GetLayer(idx);
    GEMMA_CALL_ALL_LAYER_FUNC(4)
  }
}

#undef GEMMA_CALL_TOP_FUNC1
#undef GEMMA_CALL_TOP_FUNC2
#undef GEMMA_CALL_TOP_FUNC3
#undef GEMMA_CALL_TOP_FUNC4
#undef GEMMA_CALL_LAYER_FUNC1
#undef GEMMA_CALL_LAYER_FUNC2
#undef GEMMA_CALL_LAYER_FUNC3
#undef GEMMA_CALL_LAYER_FUNC4
#undef GEMMA_CALL_ALL_LAYER_FUNC

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_