mirror of https://github.com/google/gemma.cpp.git
Further simplification to ForEachTensor, thanks I.K.
PiperOrigin-RevId: 643996210
This commit is contained in:
parent
7d0720675f
commit
704d936764
|
|
@ -318,8 +318,7 @@ void CompressWeights(const Path& weights_path,
|
||||||
WeightsF<TConfig>* weights =
|
WeightsF<TConfig>* weights =
|
||||||
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||||
Compressor compressor(pool);
|
Compressor compressor(pool);
|
||||||
ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
|
ForEachTensor<TConfig, LayerF<TConfig>>(weights, *c_weights, compressor);
|
||||||
weights, *c_weights, compressor);
|
|
||||||
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
||||||
compressor.WriteAll(pool, compressed_weights_path);
|
compressor.WriteAll(pool, compressed_weights_path);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ struct LoadCompressedWeightsT {
|
||||||
|
|
||||||
std::array<float, TConfig::kNumTensorScales> scales;
|
std::array<float, TConfig::kNumTensorScales> scales;
|
||||||
CacheLoader loader(weights);
|
CacheLoader loader(weights);
|
||||||
const void* raw_weights = nullptr; // ForEachTensor requires const.
|
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
|
||||||
ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
|
|
||||||
loader.LoadScales(scales.data(), scales.size());
|
loader.LoadScales(scales.data(), scales.size());
|
||||||
if (!loader.ReadAll(pool)) {
|
if (!loader.ReadAll(pool)) {
|
||||||
HWY_ABORT("Failed to load model weights.");
|
HWY_ABORT("Failed to load model weights.");
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
|
@ -221,17 +223,17 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is
|
||||||
// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
|
// null if raw_weights is nullptr, e.g., when loading weights from BlobStore.
|
||||||
// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
|
// Otherwise, RawLayer must be specified and we pass a float* pointing to the
|
||||||
// specified and we pass a float* pointing to the raw float weights for that
|
// raw float weights for that tensor for use by compress_weights.cc.
|
||||||
// tensor for use by compress_weights.cc.
|
|
||||||
//
|
//
|
||||||
// This avoids repeating the list of tensors between loading and compressing,
|
// This avoids repeating the list of tensors between loading and compressing,
|
||||||
// while also avoiding dependency on raw_weights.h.
|
// while also avoiding dependency on raw_weights.h.
|
||||||
template <bool kHaveRaw, class TConfig, class RawLayer = void,
|
template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
|
||||||
class RawWeights = void, class Func>
|
void ForEachTensor(RawWeightsPtr raw_weights,
|
||||||
void ForEachTensor(const RawWeights* raw_weights,
|
|
||||||
CompressedWeights<TConfig>& c_weights, Func& func) {
|
CompressedWeights<TConfig>& c_weights, Func& func) {
|
||||||
|
constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, nullptr_t>();
|
||||||
|
|
||||||
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
|
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
|
||||||
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
|
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue