Minor: ModelWeightsPtrs -> WeightsPtrs

PiperOrigin-RevId: 781954533
This commit is contained in:
Jan Wassenberg 2025-07-11 06:10:51 -07:00 committed by Copybara-Service
parent fea9a07d9b
commit 4bc44d5678
7 changed files with 42 additions and 48 deletions

View File

@ -137,7 +137,7 @@ static float EmbeddingScaling(size_t model_dim) {
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
const ModelConfig& model_config, const WeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
// Image tokens just need to be copied.
@ -187,7 +187,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
// prefix-LM mode (end > 0), which must see all tokens in one batch.
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
@ -291,7 +291,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
// token-batched `PrefillTBatch`.
static HWY_NOINLINE void Transformer(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env) {
if (HWY_UNLIKELY(runtime_config.layers_output)) {
@ -324,7 +324,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
@ -391,7 +391,7 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
// streams the token.
static void DecodeStepT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights,
const WeightsPtrs& weights,
const SampleFunc& sample_token,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
@ -451,7 +451,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const WeightsPtrs& weights, Activations& activations,
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
@ -534,7 +534,7 @@ static void GenerateT(const ModelConfig& config,
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, KVCache& kv_cache,
const WeightsPtrs& weights, KVCache& kv_cache,
MatMulEnv& env, TimingInfo& timing_info) {
Activations activations(config, runtime_config.prefill_tbatch_size,
kv_cache.SeqLen(), env.row_ptrs);
@ -550,7 +550,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
// queries, and calls `GenerateT` on each batch.
void GenerateBatchT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, AllQueries& all_queries,
const WeightsPtrs& weights, AllQueries& all_queries,
MatMulEnv& env, TimingInfo& timing_info) {
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
runtime_config.prefill_tbatch_size);
@ -568,7 +568,7 @@ void GenerateBatchT(const ModelConfig& config,
void GenerateImageTokensT(const ModelConfig& config,
const RuntimeConfig& runtime_config, size_t seq_len,
const ModelWeightsPtrs& weights, const Image& image,
const WeightsPtrs& weights, const Image& image,
ImageTokens& image_tokens, MatMulEnv& env) {
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");

View File

@ -242,7 +242,7 @@ class Gemma {
// TODO: rename to Config()
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const ModelWeightsPtrs& Weights() const { return weights_; }
const WeightsPtrs& Weights() const { return weights_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }
@ -274,7 +274,7 @@ class Gemma {
BlobReader reader_;
ModelStore model_;
std::vector<MatOwner> mat_owners_;
ModelWeightsPtrs weights_;
WeightsPtrs weights_;
GemmaChatTemplate chat_template_;
InferenceArgs inference_;
};

View File

@ -21,7 +21,7 @@ TEST(TensorInfoRegistryTest, Find) {
config.Specifier().c_str());
const TensorInfoRegistry tensors(config);
// Each tensor in the model should be known/found.
ModelWeightsPtrs weights(config);
WeightsPtrs weights(config);
weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
const TensorInfo* info = tensors.Find(t.mat.Name());
HWY_ASSERT_M(info, t.mat.Name());

View File

@ -278,7 +278,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
// kernel. The result is stored in activations.x.
static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
const WeightsPtrs& weights,
Activations& activations,
MatMulEnv& env) {
const size_t model_dim = model_config.vit_config.model_dim;
@ -308,8 +308,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
}
// Prefills the image tokens with the ViT encoder.
void PrefillVit(const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, Activations& activations,
MatMulEnv& env) {

View File

@ -26,17 +26,16 @@
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \
MatMulEnv& env); \
\
void PrefillVit(const ModelConfig& model_config, \
const ModelWeightsPtrs& weights, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, Activations& activations, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_VIT(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \
MatMulEnv& env); \
\
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, Activations& activations, \
MatMulEnv& env); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the

View File

@ -203,7 +203,7 @@ static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
}
// Zero-initializes only the allocated tensors in `*this`.
void ModelWeightsPtrs::ZeroInit() {
void WeightsPtrs::ZeroInit() {
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
gcpp::ZeroInit(t.mat);
@ -211,8 +211,8 @@ void ModelWeightsPtrs::ZeroInit() {
}
// Copies only the allocated tensors in `*this` from tensors in `other`.
void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
ForEachTensor(const_cast<WeightsPtrs*>(&other), nullptr,
[](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
@ -222,8 +222,8 @@ void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Called by WeightsOwner::Fixup.
void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Fixup(mat_owners);
});
@ -233,11 +233,11 @@ void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
});
}
std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
BlobWriter& writer) const {
std::vector<uint32_t> serialized_mat_ptrs;
// ForEachTensor is non-const but the lambda does not modify *this.
const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
const_cast<WeightsPtrs*>(this)->ForEachTensor(
nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
@ -506,12 +506,11 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
ReadBatches(reader, batches, pool);
}
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
BlobReader& reader,
const LoaderArgs& loader,
const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
const LoaderArgs& loader,
const InferenceArgs& inference,
std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<TensorToRead> tensors;

View File

@ -314,10 +314,8 @@ struct LayerWeightsPtrs {
// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
// TODO: move `gemma-inl.h` toward dispatch at each usage.
// TODO: rename to WeightsPtrs.
struct ModelWeightsPtrs {
explicit ModelWeightsPtrs(const ModelConfig& config)
struct WeightsPtrs {
explicit WeightsPtrs(const ModelConfig& config)
: config_(config),
tensors_(config_),
finder_("", tensors_), // no suffix because these are per-model.
@ -343,7 +341,7 @@ struct ModelWeightsPtrs {
}
}
~ModelWeightsPtrs() = default;
~WeightsPtrs() = default;
const ModelConfig& config_;
// Passed to finder_, hence must be initialized first.
@ -383,8 +381,7 @@ struct ModelWeightsPtrs {
// used to copy from another set of weights. Public because called by tests
// and `WeightsOwner`.
template <class Func>
void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
Func func) {
void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
LayerWeightsPtrs* other_layer1 = nullptr;
LayerWeightsPtrs* other_layer2 = nullptr;
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
@ -423,7 +420,7 @@ struct ModelWeightsPtrs {
// Zero-initializes only the allocated tensors in `*this`.
void ZeroInit();
// Copies only the allocated tensors in `*this` from tensors in `other`.
void CopyFrom(const ModelWeightsPtrs& other);
void CopyFrom(const WeightsPtrs& other);
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
// override for whether to map blobs or read them.
@ -438,7 +435,7 @@ struct ModelWeightsPtrs {
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Called by ReadFromBlobs.
void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
}; // `ModelWeightsPtrs`
}; // `WeightsPtrs`
#undef TENSOR_ARGS
} // namespace gcpp