mirror of https://github.com/google/gemma.cpp.git
Minor: ModelWeightsPtrs -> WeightsPtrs
PiperOrigin-RevId: 781954533
This commit is contained in:
parent
fea9a07d9b
commit
4bc44d5678
|
|
@ -137,7 +137,7 @@ static float EmbeddingScaling(size_t model_dim) {
|
|||
// Returns new image_token_position.
|
||||
static HWY_NOINLINE size_t
|
||||
EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
||||
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
|
||||
const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
|
||||
size_t image_token_position = 0) {
|
||||
// Image tokens just need to be copied.
|
||||
|
|
@ -187,7 +187,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
|
|||
// prefix-LM mode (end > 0), which must see all tokens in one batch.
|
||||
static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
|
|
@ -291,7 +291,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
|
|||
// token-batched `PrefillTBatch`.
|
||||
static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env) {
|
||||
if (HWY_UNLIKELY(runtime_config.layers_output)) {
|
||||
|
|
@ -324,7 +324,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
|
|||
static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
|
|
@ -391,7 +391,7 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
|
|||
// streams the token.
|
||||
static void DecodeStepT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const WeightsPtrs& weights,
|
||||
const SampleFunc& sample_token,
|
||||
Activations& activations, QBatch& qbatch,
|
||||
MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
|
|
@ -451,7 +451,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
|||
// Decode: generates one continuation token for each query in `qbatch`.
|
||||
static void GenerateT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const WeightsPtrs& weights, Activations& activations,
|
||||
QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
|
||||
// Griffin assumes that the recurrent block cache is zero-initialized.
|
||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||
|
|
@ -534,7 +534,7 @@ static void GenerateT(const ModelConfig& config,
|
|||
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, KVCache& kv_cache,
|
||||
const WeightsPtrs& weights, KVCache& kv_cache,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
Activations activations(config, runtime_config.prefill_tbatch_size,
|
||||
kv_cache.SeqLen(), env.row_ptrs);
|
||||
|
|
@ -550,7 +550,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
|||
// queries, and calls `GenerateT` on each batch.
|
||||
void GenerateBatchT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, AllQueries& all_queries,
|
||||
const WeightsPtrs& weights, AllQueries& all_queries,
|
||||
MatMulEnv& env, TimingInfo& timing_info) {
|
||||
const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
|
||||
runtime_config.prefill_tbatch_size);
|
||||
|
|
@ -568,7 +568,7 @@ void GenerateBatchT(const ModelConfig& config,
|
|||
|
||||
void GenerateImageTokensT(const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, size_t seq_len,
|
||||
const ModelWeightsPtrs& weights, const Image& image,
|
||||
const WeightsPtrs& weights, const Image& image,
|
||||
ImageTokens& image_tokens, MatMulEnv& env) {
|
||||
if (config.vit_config.layer_configs.empty()) {
|
||||
HWY_ABORT("Model does not support generating image tokens.");
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ class Gemma {
|
|||
// TODO: rename to Config()
|
||||
const ModelConfig& GetModelConfig() const { return model_.Config(); }
|
||||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||
const ModelWeightsPtrs& Weights() const { return weights_; }
|
||||
const WeightsPtrs& Weights() const { return weights_; }
|
||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
const InferenceArgs& Inference() const { return inference_; }
|
||||
|
||||
|
|
@ -274,7 +274,7 @@ class Gemma {
|
|||
BlobReader reader_;
|
||||
ModelStore model_;
|
||||
std::vector<MatOwner> mat_owners_;
|
||||
ModelWeightsPtrs weights_;
|
||||
WeightsPtrs weights_;
|
||||
GemmaChatTemplate chat_template_;
|
||||
InferenceArgs inference_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ TEST(TensorInfoRegistryTest, Find) {
|
|||
config.Specifier().c_str());
|
||||
const TensorInfoRegistry tensors(config);
|
||||
// Each tensor in the model should be known/found.
|
||||
ModelWeightsPtrs weights(config);
|
||||
WeightsPtrs weights(config);
|
||||
weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
|
||||
const TensorInfo* info = tensors.Find(t.mat.Name());
|
||||
HWY_ASSERT_M(info, t.mat.Name());
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
|
|||
// kernel. The result is stored in activations.x.
|
||||
static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
||||
const ModelConfig& model_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
const WeightsPtrs& weights,
|
||||
Activations& activations,
|
||||
MatMulEnv& env) {
|
||||
const size_t model_dim = model_config.vit_config.model_dim;
|
||||
|
|
@ -308,8 +308,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
}
|
||||
|
||||
// Prefills the image tokens with the ViT encoder.
|
||||
void PrefillVit(const ModelConfig& model_config,
|
||||
const ModelWeightsPtrs& weights,
|
||||
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
|
||||
const RuntimeConfig& runtime_config, const Image& image,
|
||||
ImageTokens& image_tokens, Activations& activations,
|
||||
MatMulEnv& env) {
|
||||
|
|
|
|||
|
|
@ -31,8 +31,7 @@ namespace gcpp {
|
|||
void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, \
|
||||
MatMulEnv& env); \
|
||||
\
|
||||
void PrefillVit(const ModelConfig& model_config, \
|
||||
const ModelWeightsPtrs& weights, \
|
||||
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \
|
||||
const RuntimeConfig& runtime_config, const Image& image, \
|
||||
ImageTokens& image_tokens, Activations& activations, \
|
||||
MatMulEnv& env); \
|
||||
|
|
|
|||
|
|
@ -203,7 +203,7 @@ static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
|
|||
}
|
||||
|
||||
// Zero-initializes only the allocated tensors in `*this`.
|
||||
void ModelWeightsPtrs::ZeroInit() {
|
||||
void WeightsPtrs::ZeroInit() {
|
||||
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
|
||||
if (!t.mat.HasPtr()) return;
|
||||
gcpp::ZeroInit(t.mat);
|
||||
|
|
@ -211,8 +211,8 @@ void ModelWeightsPtrs::ZeroInit() {
|
|||
}
|
||||
|
||||
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
||||
void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
|
||||
ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
|
||||
void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
|
||||
ForEachTensor(const_cast<WeightsPtrs*>(&other), nullptr,
|
||||
[](const TensorArgs& t) {
|
||||
if (!t.mat.HasPtr()) return;
|
||||
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
|
||||
|
|
@ -222,7 +222,7 @@ void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
|
|||
|
||||
// For reshaping file tensors to the shape expected by the code. This would
|
||||
// ideally already happen in the importer. Called by ReadFromBlobs.
|
||||
void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
||||
void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
||||
hwy::ThreadPool& pool) {
|
||||
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
|
||||
GetLayer(layer)->Fixup(mat_owners);
|
||||
|
|
@ -233,11 +233,11 @@ void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
|
|||
});
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
|
||||
std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
|
||||
BlobWriter& writer) const {
|
||||
std::vector<uint32_t> serialized_mat_ptrs;
|
||||
// ForEachTensor is non-const but the lambda does not modify *this.
|
||||
const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
|
||||
const_cast<WeightsPtrs*>(this)->ForEachTensor(
|
||||
nullptr, nullptr, [&](const TensorArgs& t) {
|
||||
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
|
||||
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
|
||||
|
|
@ -506,8 +506,7 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
|
|||
ReadBatches(reader, batches, pool);
|
||||
}
|
||||
|
||||
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
|
||||
BlobReader& reader,
|
||||
void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
|
||||
const LoaderArgs& loader,
|
||||
const InferenceArgs& inference,
|
||||
std::vector<MatOwner>& mat_owners,
|
||||
|
|
|
|||
|
|
@ -314,10 +314,8 @@ struct LayerWeightsPtrs {
|
|||
|
||||
// Holds layer-independent weight metadata and pointers plus per-layer
|
||||
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
|
||||
// TODO: move `gemma-inl.h` toward dispatch at each usage.
|
||||
// TODO: rename to WeightsPtrs.
|
||||
struct ModelWeightsPtrs {
|
||||
explicit ModelWeightsPtrs(const ModelConfig& config)
|
||||
struct WeightsPtrs {
|
||||
explicit WeightsPtrs(const ModelConfig& config)
|
||||
: config_(config),
|
||||
tensors_(config_),
|
||||
finder_("", tensors_), // no suffix because these are per-model.
|
||||
|
|
@ -343,7 +341,7 @@ struct ModelWeightsPtrs {
|
|||
}
|
||||
}
|
||||
|
||||
~ModelWeightsPtrs() = default;
|
||||
~WeightsPtrs() = default;
|
||||
|
||||
const ModelConfig& config_;
|
||||
// Passed to finder_, hence must be initialized first.
|
||||
|
|
@ -383,8 +381,7 @@ struct ModelWeightsPtrs {
|
|||
// used to copy from another set of weights. Public because called by tests
|
||||
// and `WeightsOwner`.
|
||||
template <class Func>
|
||||
void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
|
||||
Func func) {
|
||||
void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
|
||||
LayerWeightsPtrs* other_layer1 = nullptr;
|
||||
LayerWeightsPtrs* other_layer2 = nullptr;
|
||||
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
|
||||
|
|
@ -423,7 +420,7 @@ struct ModelWeightsPtrs {
|
|||
// Zero-initializes only the allocated tensors in `*this`.
|
||||
void ZeroInit();
|
||||
// Copies only the allocated tensors in `*this` from tensors in `other`.
|
||||
void CopyFrom(const ModelWeightsPtrs& other);
|
||||
void CopyFrom(const WeightsPtrs& other);
|
||||
|
||||
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
|
||||
// override for whether to map blobs or read them.
|
||||
|
|
@ -438,7 +435,7 @@ struct ModelWeightsPtrs {
|
|||
// For reshaping file tensors to the shape expected by the code. This would
|
||||
// ideally already happen in the importer. Called by ReadFromBlobs.
|
||||
void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
|
||||
}; // `ModelWeightsPtrs`
|
||||
}; // `WeightsPtrs`
|
||||
#undef TENSOR_ARGS
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
Loading…
Reference in New Issue