From 4bc44d5678550f1c1a3992b7133063f701f5b4bc Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 11 Jul 2025 06:10:51 -0700 Subject: [PATCH] Minor: ModelWeightsPtrs -> WeightsPtrs PiperOrigin-RevId: 781954533 --- gemma/gemma.cc | 18 +++++++++--------- gemma/gemma.h | 4 ++-- gemma/tensor_info_test.cc | 2 +- gemma/vit.cc | 5 ++--- gemma/vit.h | 21 ++++++++++----------- gemma/weights.cc | 25 ++++++++++++------------- gemma/weights.h | 15 ++++++--------- 7 files changed, 42 insertions(+), 48 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 79ba7cf..a5718d2 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -137,7 +137,7 @@ static float EmbeddingScaling(size_t model_dim) { // Returns new image_token_position. static HWY_NOINLINE size_t EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, - const ModelConfig& model_config, const ModelWeightsPtrs& weights, + const ModelConfig& model_config, const WeightsPtrs& weights, MatStorageT& x, const ImageTokens* image_tokens = nullptr, size_t image_token_position = 0) { // Image tokens just need to be copied. @@ -187,7 +187,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt, // prefix-LM mode (end > 0), which must see all tokens in one batch. static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, + const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos) { @@ -291,7 +291,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, // token-batched `PrefillTBatch`. static HWY_NOINLINE void Transformer(const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, + const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env) { if (HWY_UNLIKELY(runtime_config.layers_output)) { @@ -324,7 +324,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config, static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, + const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos) { @@ -391,7 +391,7 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob, // streams the token. static void DecodeStepT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, + const WeightsPtrs& weights, const SampleFunc& sample_token, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos, @@ -451,7 +451,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) { // Decode: generates one continuation token for each query in `qbatch`. static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, Activations& activations, + const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { // Griffin assumes that the recurrent block cache is zero-initialized. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { @@ -534,7 +534,7 @@ static void GenerateT(const ModelConfig& config, void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, KVCache& kv_cache, + const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), env.row_ptrs); @@ -550,7 +550,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, // queries, and calls `GenerateT` on each batch. void GenerateBatchT(const ModelConfig& config, const RuntimeConfig& runtime_config, - const ModelWeightsPtrs& weights, AllQueries& all_queries, + const WeightsPtrs& weights, AllQueries& all_queries, MatMulEnv& env, TimingInfo& timing_info) { const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); @@ -568,7 +568,7 @@ void GenerateBatchT(const ModelConfig& config, void GenerateImageTokensT(const ModelConfig& config, const RuntimeConfig& runtime_config, size_t seq_len, - const ModelWeightsPtrs& weights, const Image& image, + const WeightsPtrs& weights, const Image& image, ImageTokens& image_tokens, MatMulEnv& env) { if (config.vit_config.layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); diff --git a/gemma/gemma.h b/gemma/gemma.h index 27ce523..dfcb2ee 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -242,7 +242,7 @@ class Gemma { // TODO: rename to Config() const ModelConfig& GetModelConfig() const { return model_.Config(); } const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); } - const ModelWeightsPtrs& Weights() const { return weights_; } + const WeightsPtrs& Weights() const { return weights_; } const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } const InferenceArgs& Inference() const { return inference_; } @@ -274,7 +274,7 @@ class Gemma { BlobReader reader_; ModelStore model_; std::vector mat_owners_; - ModelWeightsPtrs weights_; + WeightsPtrs weights_; GemmaChatTemplate chat_template_; InferenceArgs inference_; }; diff --git a/gemma/tensor_info_test.cc b/gemma/tensor_info_test.cc index 427ecb4..8a95376 100644 --- a/gemma/tensor_info_test.cc +++ b/gemma/tensor_info_test.cc @@ -21,7 +21,7 @@ TEST(TensorInfoRegistryTest, Find) { config.Specifier().c_str()); const TensorInfoRegistry tensors(config); // Each tensor in the model should be known/found. - ModelWeightsPtrs weights(config); + WeightsPtrs weights(config); weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) { const TensorInfo* info = tensors.Find(t.mat.Name()); HWY_ASSERT_M(info, t.mat.Name()); diff --git a/gemma/vit.cc b/gemma/vit.cc index e96c61e..8efed43 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -278,7 +278,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, // kernel. The result is stored in activations.x. static HWY_NOINLINE void EmbedImagePatches(const Image& image, const ModelConfig& model_config, - const ModelWeightsPtrs& weights, + const WeightsPtrs& weights, Activations& activations, MatMulEnv& env) { const size_t model_dim = model_config.vit_config.model_dim; @@ -308,8 +308,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image, } // Prefills the image tokens with the ViT encoder. -void PrefillVit(const ModelConfig& model_config, - const ModelWeightsPtrs& weights, +void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens, Activations& activations, MatMulEnv& env) { diff --git a/gemma/vit.h b/gemma/vit.h index 34d2307..d6562f6 100644 --- a/gemma/vit.h +++ b/gemma/vit.h @@ -26,17 +26,16 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \ - MatMulEnv& env); \ - \ - void PrefillVit(const ModelConfig& model_config, \ - const ModelWeightsPtrs& weights, \ - const RuntimeConfig& runtime_config, const Image& image, \ - ImageTokens& image_tokens, Activations& activations, \ - MatMulEnv& env); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \ + MatMulEnv& env); \ + \ + void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \ + const RuntimeConfig& runtime_config, const Image& image, \ + ImageTokens& image_tokens, Activations& activations, \ + MatMulEnv& env); \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/gemma/weights.cc b/gemma/weights.cc index 3fb7dee..3fdb2c1 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -203,7 +203,7 @@ static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) { } // Zero-initializes only the allocated tensors in `*this`. -void ModelWeightsPtrs::ZeroInit() { +void WeightsPtrs::ZeroInit() { ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) { if (!t.mat.HasPtr()) return; gcpp::ZeroInit(t.mat); @@ -211,8 +211,8 @@ void ModelWeightsPtrs::ZeroInit() { } // Copies only the allocated tensors in `*this` from tensors in `other`. -void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) { - ForEachTensor(const_cast(&other), nullptr, +void WeightsPtrs::CopyFrom(const WeightsPtrs& other) { + ForEachTensor(const_cast(&other), nullptr, [](const TensorArgs& t) { if (!t.mat.HasPtr()) return; HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr()); @@ -222,8 +222,8 @@ void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) { // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Called by WeightsOwner::Fixup. -void ModelWeightsPtrs::Fixup(std::vector& mat_owners, - hwy::ThreadPool& pool) { +void WeightsPtrs::Fixup(std::vector& mat_owners, + hwy::ThreadPool& pool) { pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { GetLayer(layer)->Fixup(mat_owners); }); @@ -233,11 +233,11 @@ void ModelWeightsPtrs::Fixup(std::vector& mat_owners, }); } -std::vector ModelWeightsPtrs::AddTensorDataToWriter( +std::vector WeightsPtrs::AddTensorDataToWriter( BlobWriter& writer) const { std::vector serialized_mat_ptrs; // ForEachTensor is non-const but the lambda does not modify *this. - const_cast(this)->ForEachTensor( + const_cast(this)->ForEachTensor( nullptr, nullptr, [&](const TensorArgs& t) { if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return; HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name()); @@ -506,12 +506,11 @@ static void MapOrReadAll(std::vector& tensors, BlobReader& reader, ReadBatches(reader, batches, pool); } -void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model, - BlobReader& reader, - const LoaderArgs& loader, - const InferenceArgs& inference, - std::vector& mat_owners, - hwy::ThreadPool& pool) { +void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader, + const LoaderArgs& loader, + const InferenceArgs& inference, + std::vector& mat_owners, + hwy::ThreadPool& pool) { // List of tensors to read/map, and where from. std::vector tensors; diff --git a/gemma/weights.h b/gemma/weights.h index 250450a..63ed70d 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -314,10 +314,8 @@ struct LayerWeightsPtrs { // Holds layer-independent weight metadata and pointers plus per-layer // `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. -// TODO: move `gemma-inl.h` toward dispatch at each usage. -// TODO: rename to WeightsPtrs. -struct ModelWeightsPtrs { - explicit ModelWeightsPtrs(const ModelConfig& config) +struct WeightsPtrs { + explicit WeightsPtrs(const ModelConfig& config) : config_(config), tensors_(config_), finder_("", tensors_), // no suffix because these are per-model. @@ -343,7 +341,7 @@ struct ModelWeightsPtrs { } } - ~ModelWeightsPtrs() = default; + ~WeightsPtrs() = default; const ModelConfig& config_; // Passed to finder_, hence must be initialized first. @@ -383,8 +381,7 @@ struct ModelWeightsPtrs { // used to copy from another set of weights. Public because called by tests // and `WeightsOwner`. template - void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2, - Func func) { + void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) { LayerWeightsPtrs* other_layer1 = nullptr; LayerWeightsPtrs* other_layer2 = nullptr; func(TENSOR_ARGS(embedder_input_embedding, kMustRead)); @@ -423,7 +420,7 @@ struct ModelWeightsPtrs { // Zero-initializes only the allocated tensors in `*this`. void ZeroInit(); // Copies only the allocated tensors in `*this` from tensors in `other`. - void CopyFrom(const ModelWeightsPtrs& other); + void CopyFrom(const WeightsPtrs& other); // Reads tensor data from `BlobStore` or aborts on error. `map` is a user // override for whether to map blobs or read them. @@ -438,7 +435,7 @@ struct ModelWeightsPtrs { // For reshaping file tensors to the shape expected by the code. This would // ideally already happen in the importer. Called by ReadFromBlobs. void Fixup(std::vector& mat_owners, hwy::ThreadPool& pool); -}; // `ModelWeightsPtrs` +}; // `WeightsPtrs` #undef TENSOR_ARGS } // namespace gcpp