mirror of https://github.com/google/gemma.cpp.git
Minor: ModelWeightsPtrs -> WeightsPtrs
PiperOrigin-RevId: 781954533
commit 4bc44d5678 (parent fea9a07d9b)
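Reviewer note: this is a pure rename (ModelWeightsPtrs -> WeightsPtrs) plus the signature reflows it enables; no behavior changes. If out-of-tree callers needed a migration window, the usual C++ idiom would be a transitional alias. The line below is a sketch of that idiom, not something this commit adds:

  // Hypothetical shim, NOT part of this commit: keeps the old spelling
  // compiling while call sites migrate to the new name.
  using ModelWeightsPtrs = WeightsPtrs;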
@@ -137,7 +137,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
 EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
-             const ModelConfig& model_config, const ModelWeightsPtrs& weights,
+             const ModelConfig& model_config, const WeightsPtrs& weights,
              MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
              size_t image_token_position = 0) {
   // Image tokens just need to be copied.
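The "Image tokens just need to be copied" comment refers to the multimodal path: for prompt positions inside the image span, EmbedMMToken copies a row of the precomputed `image_tokens` into `x` instead of looking up the embedding table. A rough sketch of that branch, using the parameters visible in the hunk; the Rows()/Row()/Cols() accessors are assumptions, not the actual API:

  // Sketch only: copy one precomputed image embedding into row qi of x.
  if (image_tokens != nullptr && image_token_position < image_tokens->Rows()) {
    const float* src = image_tokens->Row(image_token_position);
    std::copy(src, src + x.Cols(), x.Row(qi));
    return image_token_position + 1;  // the "new image_token_position"
  }
  // Otherwise: normal lookup in the embedding table, then scaling.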
@@ -187,7 +187,7 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
 // prefix-LM mode (end > 0), which must see all tokens in one batch.
 static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
-                                       const ModelWeightsPtrs& weights,
+                                       const WeightsPtrs& weights,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
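`non_eos` tracks which of the (up to 4096) queries in the batch are still generating: bit qi stays set until query qi emits end-of-sequence. Assuming highway's BitSet4096 API (Set/Clear/Any), the bookkeeping looks roughly like:

  hwy::BitSet4096<> non_eos;
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) non_eos.Set(qi);
  // After sampling a token for query qi:
  //   if (token == eos_id) non_eos.Clear(qi);
  // The decode loop continues while non_eos.Any().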
@@ -291,7 +291,7 @@ static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
 // token-batched `PrefillTBatch`.
 static HWY_NOINLINE void Transformer(const ModelConfig& config,
                                      const RuntimeConfig& runtime_config,
-                                     const ModelWeightsPtrs& weights,
+                                     const WeightsPtrs& weights,
                                      Activations& activations, QBatch& qbatch,
                                      MatMulEnv& env) {
   if (HWY_UNLIKELY(runtime_config.layers_output)) {
@@ -324,7 +324,7 @@ static HWY_NOINLINE void Transformer(const ModelConfig& config,
 static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
                                        const ModelConfig& config,
                                        const RuntimeConfig& runtime_config,
-                                       const ModelWeightsPtrs& weights,
+                                       const WeightsPtrs& weights,
                                        Activations& activations, QBatch& qbatch,
                                        MatMulEnv& env,
                                        hwy::BitSet4096<>& non_eos) {
@@ -391,7 +391,7 @@ static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
 // streams the token.
 static void DecodeStepT(const ModelConfig& config,
                         const RuntimeConfig& runtime_config,
-                        const ModelWeightsPtrs& weights,
+                        const WeightsPtrs& weights,
                         const SampleFunc& sample_token,
                         Activations& activations, QBatch& qbatch,
                         MatMulEnv& env, hwy::BitSet4096<>& non_eos,
@@ -451,7 +451,7 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
 // Decode: generates one continuation token for each query in `qbatch`.
 static void GenerateT(const ModelConfig& config,
                       const RuntimeConfig& runtime_config,
-                      const ModelWeightsPtrs& weights, Activations& activations,
+                      const WeightsPtrs& weights, Activations& activations,
                       QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
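The loop following the Griffin comment presumably resets each query's recurrent state before decoding, so Griffin layers start from zeros. The cache accessor below is a guess for illustration only:

  // Hedged sketch; the real member names differ.
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    // e.g. qbatch.KV(qi).griffin_cache.ZeroInit();  // hypothetical accessor
  }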
@@ -534,7 +534,7 @@ static void GenerateT(const ModelConfig& config,
 void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
-                     const ModelWeightsPtrs& weights, KVCache& kv_cache,
+                     const WeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
   Activations activations(config, runtime_config.prefill_tbatch_size,
                           kv_cache.SeqLen(), env.row_ptrs);
@@ -550,7 +550,7 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
 // queries, and calls `GenerateT` on each batch.
 void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const ModelWeightsPtrs& weights, AllQueries& all_queries,
+                    const WeightsPtrs& weights, AllQueries& all_queries,
                     MatMulEnv& env, TimingInfo& timing_info) {
   const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
                                         runtime_config.prefill_tbatch_size);
@@ -568,7 +568,7 @@ void GenerateBatchT(const ModelConfig& config,
 
 void GenerateImageTokensT(const ModelConfig& config,
                           const RuntimeConfig& runtime_config, size_t seq_len,
-                          const ModelWeightsPtrs& weights, const Image& image,
+                          const WeightsPtrs& weights, const Image& image,
                           ImageTokens& image_tokens, MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
@@ -242,7 +242,7 @@ class Gemma {
   // TODO: rename to Config()
   const ModelConfig& GetModelConfig() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
-  const ModelWeightsPtrs& Weights() const { return weights_; }
+  const WeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const InferenceArgs& Inference() const { return inference_; }
 
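Call sites of the accessor are unaffected by the rename as long as they bind the result by const reference or auto. For example (hypothetical caller; construction of `gemma` elided):

  const WeightsPtrs& weights = gemma.Weights();        // new return type
  const ModelConfig& config = gemma.GetModelConfig();  // unchanged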
@@ -274,7 +274,7 @@ class Gemma {
   BlobReader reader_;
   ModelStore model_;
   std::vector<MatOwner> mat_owners_;
-  ModelWeightsPtrs weights_;
+  WeightsPtrs weights_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
 };
@@ -21,7 +21,7 @@ TEST(TensorInfoRegistryTest, Find) {
                config.Specifier().c_str());
   const TensorInfoRegistry tensors(config);
   // Each tensor in the model should be known/found.
-  ModelWeightsPtrs weights(config);
+  WeightsPtrs weights(config);
   weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
     const TensorInfo* info = tensors.Find(t.mat.Name());
     HWY_ASSERT_M(info, t.mat.Name());
@@ -278,7 +278,7 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
 // kernel. The result is stored in activations.x.
 static HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                            const ModelConfig& model_config,
-                                           const ModelWeightsPtrs& weights,
+                                           const WeightsPtrs& weights,
                                            Activations& activations,
                                            MatMulEnv& env) {
   const size_t model_dim = model_config.vit_config.model_dim;
@@ -308,8 +308,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
 }
 
 // Prefills the image tokens with the ViT encoder.
-void PrefillVit(const ModelConfig& model_config,
-                const ModelWeightsPtrs& weights,
+void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
                 const RuntimeConfig& runtime_config, const Image& image,
                 ImageTokens& image_tokens, Activations& activations,
                 MatMulEnv& env) {
gemma/vit.h (21 changed lines)
@@ -26,17 +26,16 @@
 namespace gcpp {
 
 // Passed to HWY_VISIT_TARGETS; declares for one target.
 #define GEMMA_DECL_VIT(TARGET, NAMESPACE)                                     \
   namespace NAMESPACE {                                                      \
   void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,       \
               MatMulEnv& env);                                               \
                                                                              \
-  void PrefillVit(const ModelConfig& model_config,                           \
-                  const ModelWeightsPtrs& weights,                           \
+  void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights, \
                   const RuntimeConfig& runtime_config, const Image& image,   \
                   ImageTokens& image_tokens, Activations& activations,       \
                   MatMulEnv& env);                                           \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
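Per the comment above, GEMMA_DECL_VIT is handed to HWY_VISIT_TARGETS, which expands it once per enabled SIMD target so each target namespace gets its own FFWVit/PrefillVit declarations. The invocation shape is implied by that comment (the exact call site is not in this diff):

  // Expands GEMMA_DECL_VIT for every compiled SIMD target namespace.
  HWY_VISIT_TARGETS(GEMMA_DECL_VIT)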
@@ -203,7 +203,7 @@ static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
 }
 
 // Zero-initializes only the allocated tensors in `*this`.
-void ModelWeightsPtrs::ZeroInit() {
+void WeightsPtrs::ZeroInit() {
   ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
     if (!t.mat.HasPtr()) return;
     gcpp::ZeroInit(t.mat);
@@ -211,8 +211,8 @@ void ModelWeightsPtrs::ZeroInit() {
 }
 
 // Copies only the allocated tensors in `*this` from tensors in `other`.
-void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
-  ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
+void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
+  ForEachTensor(const_cast<WeightsPtrs*>(&other), nullptr,
                 [](const TensorArgs& t) {
                   if (!t.mat.HasPtr()) return;
                   HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
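A sketch of how CopyFrom would be called, based on the constructor shown in the test hunk above and the "only the allocated tensors" comment (the allocation step is elided and assumed to be done by the owning code):

  WeightsPtrs copy(config);   // same ModelConfig as the source
  // ... allocate copy's tensors (elided) ...
  copy.CopyFrom(weights);     // copies only tensors allocated in `copy`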
@@ -222,8 +222,8 @@ void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
 
 // For reshaping file tensors to the shape expected by the code. This would
 // ideally already happen in the importer. Called by WeightsOwner::Fixup.
-void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
-                             hwy::ThreadPool& pool) {
+void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
+                        hwy::ThreadPool& pool) {
   pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
     GetLayer(layer)->Fixup(mat_owners);
   });
@@ -233,11 +233,11 @@ void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
   });
 }
 
-std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
+std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
     BlobWriter& writer) const {
   std::vector<uint32_t> serialized_mat_ptrs;
   // ForEachTensor is non-const but the lambda does not modify *this.
-  const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
+  const_cast<WeightsPtrs*>(this)->ForEachTensor(
       nullptr, nullptr, [&](const TensorArgs& t) {
         if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
         HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
@@ -506,12 +506,11 @@ static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
   ReadBatches(reader, batches, pool);
 }
 
-void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
-                                     BlobReader& reader,
-                                     const LoaderArgs& loader,
-                                     const InferenceArgs& inference,
-                                     std::vector<MatOwner>& mat_owners,
-                                     hwy::ThreadPool& pool) {
+void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                                const LoaderArgs& loader,
+                                const InferenceArgs& inference,
+                                std::vector<MatOwner>& mat_owners,
+                                hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;
 
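A plausible call, combining the member list from the Gemma class hunk above with the new signature; the `loader`, `inference`, and `pool` arguments are assumptions about the surrounding constructor code:

  // Hedged sketch of a load sequence inside Gemma's constructor.
  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, pool);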
@@ -314,10 +314,8 @@ struct LayerWeightsPtrs {
 
 // Holds layer-independent weight metadata and pointers plus per-layer
 // `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
-// TODO: move `gemma-inl.h` toward dispatch at each usage.
-// TODO: rename to WeightsPtrs.
-struct ModelWeightsPtrs {
-  explicit ModelWeightsPtrs(const ModelConfig& config)
+struct WeightsPtrs {
+  explicit WeightsPtrs(const ModelConfig& config)
       : config_(config),
         tensors_(config_),
         finder_("", tensors_),  // no suffix because these are per-model.
@@ -343,7 +341,7 @@ struct ModelWeightsPtrs {
     }
   }
 
-  ~ModelWeightsPtrs() = default;
+  ~WeightsPtrs() = default;
 
   const ModelConfig& config_;
   // Passed to finder_, hence must be initialized first.
@@ -383,8 +381,7 @@ struct ModelWeightsPtrs {
   // used to copy from another set of weights. Public because called by tests
   // and `WeightsOwner`.
   template <class Func>
-  void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
-                     Func func) {
+  void ForEachTensor(WeightsPtrs* other1, WeightsPtrs* other2, Func func) {
     LayerWeightsPtrs* other_layer1 = nullptr;
    LayerWeightsPtrs* other_layer2 = nullptr;
     func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
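This visitor is the single mechanism behind ZeroInit, CopyFrom, and AddTensorDataToWriter above. A small illustrative use, mirroring the lambda shape already shown in this diff:

  // Count tensors that currently have backing storage.
  size_t num_allocated = 0;
  weights.ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
    if (t.mat.HasPtr()) ++num_allocated;
  });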
@@ -423,7 +420,7 @@ struct ModelWeightsPtrs {
   // Zero-initializes only the allocated tensors in `*this`.
   void ZeroInit();
   // Copies only the allocated tensors in `*this` from tensors in `other`.
-  void CopyFrom(const ModelWeightsPtrs& other);
+  void CopyFrom(const WeightsPtrs& other);
 
   // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
   // override for whether to map blobs or read them.
@@ -438,7 +435,7 @@ struct ModelWeightsPtrs {
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Called by ReadFromBlobs.
   void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
-};  // `ModelWeightsPtrs`
+};  // `WeightsPtrs`
 #undef TENSOR_ARGS
 
 }  // namespace gcpp