Minor: ModelWeightsPtrs -> WeightsPtrs

PiperOrigin-RevId: 781954533
Jan Wassenberg 2025-07-11 06:10:51 -07:00 committed by Copybara-Service
parent fea9a07d9b
commit 4bc44d5678
7 changed files with 42 additions and 48 deletions

View File

@@ -137,7 +137,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
 EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
-             const ModelConfig& model_config, const ModelWeightsPtrs& weights,
+             const ModelConfig& model_config, const WeightsPtrs& weights,
              MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
   // Image tokens just need to be copied.
@@ -187,7 +187,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
 // prefix-LM mode (end > 0), which must see all tokens in one batch.
 static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
-                                       const ModelWeightsPtrs& weights,
+                                       const WeightsPtrs& weights,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
@@ -291,7 +291,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 // token-batched `PrefillTBatch`.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
-                                     const ModelWeightsPtrs& weights,
+                                     const WeightsPtrs& weights,
                                      Activations& activations, QBatch& qbatch,
                                      MatMulEnv& env) {
   if (HWY_UNLIKELY(runtime_config.layers_output)) {
@@ -324,7 +324,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
-                                       const ModelWeightsPtrs& weights,
+                                       const WeightsPtrs& weights,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
@@ -391,7 +391,7 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
 // streams the token.
 static void DecodeStepT(const ModelConfig& config,
                         const RuntimeConfig& runtime_config,
-                        const ModelWeightsPtrs& weights,
+                        const WeightsPtrs& weights,
                         const SampleFunc& sample_token,
                         Activations& activations, QBatch& qbatch,
                         MatMulEnv& env, hwy::BitSet4096<>& non_eos,
@@ -451,7 +451,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
 // Decode: generates one continuation token for each query in `qbatch`.
 static void GenerateT(const ModelConfig& config,
                       const RuntimeConfig& runtime_config,
-                      const ModelWeightsPtrs& weights, Activations& activations,
+                      const WeightsPtrs& weights, Activations& activations,
                       QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
@@ -534,7 +534,7 @@ static void GenerateT(const ModelConfig& config,
 void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
-                     const ModelWeightsPtrs& weights, KVCache& kv_cache,
+                     const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
                           kv_cache.SeqLen(), env.row_ptrs);
@@ -550,7 +550,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
 // queries, and calls `GenerateT` on each batch.
 void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const ModelWeightsPtrs& weights, AllQueries& all_queries,
+                    const WeightsPtrs& weights, AllQueries& all_queries,
                     MatMulEnv& env, TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
@@ -568,7 +568,7 @@ void GenerateBatchT(const ModelConfig& config,
 void GenerateImageTokensT(const ModelConfig& config,
                           const RuntimeConfig& runtime_config, size_t seq_len,
-                          const ModelWeightsPtrs& weights, const Image& image,
+                          const WeightsPtrs& weights, const Image& image,
                           ImageTokens& image_tokens, MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
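The signatures above show how the renamed type flows through the generation entry points. As an illustration only (not part of this commit), here is a minimal sketch of a call to GenerateSingleT as declared above; it assumes the caller is in a scope where these functions are visible and that every collaborator object has already been constructed elsewhere.

// Hypothetical call site for GenerateSingleT; construction of every argument
// (config, runtime_config, weights, kv_cache, env, timing_info) is assumed
// to happen elsewhere and is not part of this diff.
void RunOneQuery(const PromptTokens& prompt, const ModelConfig& config,
                 const RuntimeConfig& runtime_config, const WeightsPtrs& weights,
                 KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) {
  const size_t pos = 0;         // generation starts at the beginning
  const size_t prefix_end = 0;  // 0 = no prefix-LM region (see PrefillTBatch)
  GenerateSingleT(prompt, pos, prefix_end, config, runtime_config, weights,
                  kv_cache, env, timing_info);
}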

View File

@@ -242,7 +242,7 @@ class Gemma {
   // TODO: rename to Config()
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
-  const ModelWeightsPtrs& Weights() const { return weights_; }
+  const WeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const InferenceArgs& Inference() const { return inference_; }
@@ -274,7 +274,7 @@ class Gemma {
   BlobReader reader_;
   ModelStore model_;
   std::vector<MatOwner> mat_owners_;
-  ModelWeightsPtrs weights_;
+  WeightsPtrs weights_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
 };
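The Gemma facade keeps owning weights_ and only hands it out by const reference, so callers see the rename solely in the spelled-out return type. A purely illustrative sketch of read-only use of the accessors above (the gemma object is assumed to exist; nothing here is added by the commit):

// Hypothetical read-only inspection via the accessors shown above.
void LogModelSummary(const Gemma& gemma) {
  const ModelConfig& config = gemma.GetModelConfig();  // per the TODO: Config()
  const WeightsPtrs& weights = gemma.Weights();        // renamed return type
  const GemmaTokenizer& tokenizer = gemma.Tokenizer();
  // e.g. print config.Specifier() and pass `weights` to the Generate* calls.
  (void)config;
  (void)weights;
  (void)tokenizer;
}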

View File

@@ -21,7 +21,7 @@ TEST(TensorInfoRegistryTest, Find) {
              config.Specifier().c_str());
     const TensorInfoRegistry tensors(config);
     // Each tensor in the model should be known/found.
-    ModelWeightsPtrs weights(config);
+    WeightsPtrs weights(config);
     weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
       const TensorInfo* info = tensors.Find(t.mat.Name());
       HWY_ASSERT_M(info, t.mat.Name());
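The test above exercises ForEachTensor, whose shape is unchanged by the rename: the first two pointer arguments are optional "other" weight sets to visit in parallel (used by CopyFrom below), and the callback receives a TensorArgs per tensor. A minimal sketch of the same visitor pattern, as a hypothetical helper rather than code from this commit:

// Hypothetical helper using the same visitor pattern as the test above.
size_t CountAllocatedTensors(const ModelConfig& config) {
  WeightsPtrs weights(config);
  size_t num_allocated = 0;
  // nullptr, nullptr: no "other" weight sets to visit alongside this one.
  weights.ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
    if (t.mat.HasPtr()) ++num_allocated;  // count only tensors with storage
  });
  return num_allocated;
}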

View File

@@ -278,7 +278,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
 // kernel. The result is stored in activations.x.
 static HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                            const ModelConfig& model_config,
-                                           const ModelWeightsPtrs& weights,
+                                           const WeightsPtrs& weights,
                                            Activations& activations,
                                            MatMulEnv& env) {
   const size_t model_dim = model_config.vit_config.model_dim;
@@ -308,8 +308,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
 }
 
 // Prefills the image tokens with the ViT encoder.
-void PrefillVit(const ModelConfig& model_config,
-                const ModelWeightsPtrs& weights,
+void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
                 const RuntimeConfig& runtime_config, const Image& image,
                 ImageTokens& image_tokens, Activations& activations,
                 MatMulEnv& env) {

View File

@@ -26,17 +26,16 @@
 namespace gcpp {
 
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_VIT(TARGET, NAMESPACE)                                   \
   namespace NAMESPACE {                                                     \
   void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,     \
               MatMulEnv& env);                                              \
                                                                             \
-  void PrefillVit(const ModelConfig& model_config,                          \
-                  const ModelWeightsPtrs& weights,                          \
-                  const RuntimeConfig& runtime_config, const Image& image,  \
-                  ImageTokens& image_tokens, Activations& activations,      \
-                  MatMulEnv& env);                                          \
+  void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \
+                  const RuntimeConfig& runtime_config, const Image& image,     \
+                  ImageTokens& image_tokens, Activations& activations,         \
+                  MatMulEnv& env);                                             \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */               \
   }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
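As the comment above says, GEMMA_DECL_VIT is handed to HWY_VISIT_TARGETS, which expands it once per enabled SIMD target so that each target namespace gets its own declarations. A sketch of what that invocation presumably looks like at the use site (the actual line lives elsewhere in the header, not in this hunk):

// Presumed use of the macro defined above: declare FFWVit/PrefillVit in the
// per-target namespace of every enabled SIMD target.
HWY_VISIT_TARGETS(GEMMA_DECL_VIT)
#undef GEMMA_DECL_VIT  // assumption: the macro is only needed for this expansion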

View File

@@ -203,7 +203,7 @@ static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
 }
 
 // Zero-initializes only the allocated tensors in `*this`.
-void ModelWeightsPtrs::ZeroInit() {
+void WeightsPtrs::ZeroInit() {
   ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
     if (!t.mat.HasPtr()) return;
     gcpp::ZeroInit(t.mat);
@@ -211,8 +211,8 @@ void ModelWeightsPtrs::ZeroInit() {
 }
 
 // Copies only the allocated tensors in `*this` from tensors in `other`.
-void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
-  ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
+void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
+  ForEachTensor(const_cast<WeightsPtrs*>(&other), nullptr,
                 [](const TensorArgs& t) {
                   if (!t.mat.HasPtr()) return;
                   HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
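CopyFrom passes &other as the first ForEachTensor argument, so each visited TensorArgs carries other_mat1, the matching tensor in the source weights. A hedged usage sketch (hypothetical helper; both weight sets must come from the same config so that tensors pair up):

// Hypothetical sketch: duplicate a set of weights. Only tensors with storage
// in `dst` are touched; CopyFrom asserts that `src` has storage for each.
void CloneWeights(const WeightsPtrs& src, WeightsPtrs& dst) {
  dst.ZeroInit();     // zero-initializes only the allocated tensors in dst
  dst.CopyFrom(src);  // pairs tensors via ForEachTensor(&src, nullptr, ...)
}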
@@ -222,8 +222,8 @@ void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
 
 // For reshaping file tensors to the shape expected by the code. This would
 // ideally already happen in the importer. Called by WeightsOwner::Fixup.
-void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
-                             hwy::ThreadPool& pool) {
+void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
+                        hwy::ThreadPool& pool) {
   pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
     GetLayer(layer)->Fixup(mat_owners);
   });
@@ -233,11 +233,11 @@ void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
   });
 }
 
-std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
+std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
     BlobWriter& writer) const {
   std::vector<uint32_t> serialized_mat_ptrs;
   // ForEachTensor is non-const but the lambda does not modify *this.
-  const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
+  const_cast<WeightsPtrs*>(this)->ForEachTensor(
       nullptr, nullptr, [&](const TensorArgs& t) {
         if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
         HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
@@ -506,12 +506,11 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
   ReadBatches(reader, batches, pool);
 }
 
-void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
-                                     BlobReader& reader,
-                                     const LoaderArgs& loader,
-                                     const InferenceArgs& inference,
-                                     std::vector<MatOwner>& mat_owners,
-                                     hwy::ThreadPool& pool) {
+void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                                const LoaderArgs& loader,
+                                const InferenceArgs& inference,
+                                std::vector<MatOwner>& mat_owners,
+                                hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;
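The reflowed ReadFromBlobs signature is the loading entry point: it builds the list of tensors to map or read from the BlobStore and then calls Fixup. A hedged sketch of a caller, with the model store, blob reader, args and thread pool all assumed to be constructed elsewhere (not part of this diff):

// Hypothetical caller of ReadFromBlobs as declared above; constructing the
// model store, reader, loader/inference args and pool is outside this diff.
void LoadWeights(const ModelStore& model, BlobReader& reader,
                 const LoaderArgs& loader, const InferenceArgs& inference,
                 hwy::ThreadPool& pool) {
  WeightsPtrs weights(model.Config());
  std::vector<MatOwner> mat_owners;  // owns the allocated tensor storage
  weights.ReadFromBlobs(model, reader, loader, inference, mat_owners, pool);
  // In real code these are long-lived members (cf. Gemma::mat_owners_/weights_).
}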

View File

@@ -314,10 +314,8 @@ struct LayerWeightsPtrs {
 
 // Holds layer-independent weight metadata and pointers plus per-layer
 // `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
-// TODO: move `gemma-inl.h` toward dispatch at each usage.
-// TODO: rename to WeightsPtrs.
-struct ModelWeightsPtrs {
-  explicit ModelWeightsPtrs(const ModelConfig& config)
+struct WeightsPtrs {
+  explicit WeightsPtrs(const ModelConfig& config)
       : config_(config),
         tensors_(config_),
         finder_("", tensors_),  // no suffix because these are per-model.
@@ -343,7 +341,7 @@ struct ModelWeightsPtrs {
     }
   }
 
-  ~ModelWeightsPtrs() = default;
+  ~WeightsPtrs() = default;
 
   const ModelConfig& config_;
   // Passed to finder_, hence must be initialized first.
@@ -383,8 +381,7 @@ struct ModelWeightsPtrs {
   // used to copy from another set of weights. Public because called by tests
   // and `WeightsOwner`.
   template <class Func>
-  void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
-                     Func func) {
+  void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
     LayerWeightsPtrs* other_layer1 = nullptr;
     LayerWeightsPtrs* other_layer2 = nullptr;
     func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
@@ -423,7 +420,7 @@ struct ModelWeightsPtrs {
   // Zero-initializes only the allocated tensors in `*this`.
   void ZeroInit();
   // Copies only the allocated tensors in `*this` from tensors in `other`.
-  void CopyFrom(const ModelWeightsPtrs& other);
+  void CopyFrom(const WeightsPtrs& other);
 
   // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
   // override for whether to map blobs or read them.
@@ -438,7 +435,7 @@ struct ModelWeightsPtrs {
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Called by ReadFromBlobs.
   void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
-};  // `ModelWeightsPtrs`
+};  // `WeightsPtrs`
 
 #undef TENSOR_ARGS
 }  // namespace gcpp