mirror of https://github.com/google/gemma.cpp.git
Further simplification to ForEachTensor, thanks I.K.
PiperOrigin-RevId: 643996210
commit 704d936764
parent 7d0720675f
compress_weights.cc
@@ -318,8 +318,7 @@ void CompressWeights(const Path& weights_path,
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
-  ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
-      weights, *c_weights, compressor);
+  ForEachTensor<TConfig, LayerF<TConfig>>(weights, *c_weights, compressor);
   compressor.AddScales(weights->scales.data(), weights->scales.size());
   compressor.WriteAll(pool, compressed_weights_path);
gemma/weights.cc
@@ -46,8 +46,7 @@ struct LoadCompressedWeightsT {
 
   std::array<float, TConfig::kNumTensorScales> scales;
   CacheLoader loader(weights);
-  const void* raw_weights = nullptr;  // ForEachTensor requires const.
-  ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
+  ForEachTensor<TConfig>(nullptr, *c_weights, loader);
   loader.LoadScales(scales.data(), scales.size());
   if (!loader.ReadAll(pool)) {
     HWY_ABORT("Failed to load model weights.");
gemma/weights.h
@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
 
+#include <stddef.h>
+
 #include "compression/compress.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
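(The new <stddef.h> include presumably supplies nullptr_t, which the updated ForEachTensor below compares against to detect at compile time whether a raw-weights pointer was passed.)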
gemma/weights.h
@@ -221,17 +223,17 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is
-// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
-// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
-// specified and we pass a float* pointing to the raw float weights for that
-// tensor for use by compress_weights.cc.
+// null if raw_weights is nullptr, e.g., when loading weights from BlobStore.
+// Otherwise, RawLayer must be specified and we pass a float* pointing to the
+// raw float weights for that tensor for use by compress_weights.cc.
 //
 // This avoids repeating the list of tensors between loading and compressing,
 // while also avoiding dependency on raw_weights.h.
-template <bool kHaveRaw, class TConfig, class RawLayer = void,
-          class RawWeights = void, class Func>
-void ForEachTensor(const RawWeights* raw_weights,
+template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
+void ForEachTensor(RawWeightsPtr raw_weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
+  constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, nullptr_t>();
+
   GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
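For reference, here is a minimal, self-contained C++17 sketch of the dispatch technique this commit adopts. It is not the actual gemma.cpp code: RawTensor, CompressedTensor, and Visit are hypothetical stand-ins, and std::is_same_v replaces hwy::IsSame. The idea is that whether raw float weights exist is deduced from the argument's type, since a literal nullptr deduces to nullptr_t, rather than stated via an explicit kHaveRaw template argument.

#include <cstddef>
#include <cstdio>
#include <type_traits>

struct RawTensor { float data[4] = {1, 2, 3, 4}; };  // stand-in for raw float weights
struct CompressedTensor { int id = 0; };             // stand-in for CompressedArray

// Whether raw weights exist is deduced from the pointer's type: a literal
// nullptr argument deduces to std::nullptr_t, so no boolean flag is needed.
template <class RawWeightsPtr, class Func>
void Visit(RawWeightsPtr raw_weights, CompressedTensor& c_tensor, Func& func) {
  constexpr bool kHaveRaw = !std::is_same_v<RawWeightsPtr, std::nullptr_t>;
  if constexpr (kHaveRaw) {
    func("tensor", raw_weights->data, c_tensor);             // raw floats available (compressing)
  } else {
    func("tensor", static_cast<float*>(nullptr), c_tensor);  // no raw weights (loading)
  }
}

int main() {
  CompressedTensor c;
  auto func = [](const char* name, float* raw, CompressedTensor&) {
    std::printf("%s: raw %s\n", name, raw ? "present" : "null");
  };
  RawTensor raw;
  Visit(&raw, c, func);     // like compress_weights.cc: pass the raw struct
  Visit(nullptr, c, func);  // like the BlobStore loading path: pass nullptr
  return 0;
}

This is why both call sites above simplify: compress_weights.cc keeps passing its raw weights pointer unchanged, while the loading path passes nullptr directly and no longer needs the dummy const void* variable.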