Further simplification to ForEachTensor, thanks I.K.

PiperOrigin-RevId: 643996210
This commit is contained in:
Jan Wassenberg 2024-06-17 07:11:57 -07:00 committed by Copybara-Service
parent 7d0720675f
commit 704d936764
3 changed files with 11 additions and 11 deletions

View File

@@ -318,8 +318,7 @@ void CompressWeights(const Path& weights_path,
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
weights, *c_weights, compressor);
ForEachTensor<TConfig, LayerF<TConfig>>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path);

View File

@@ -46,8 +46,7 @@ struct LoadCompressedWeightsT {
std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights);
const void* raw_weights = nullptr; // ForEachTensor requires const.
ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
loader.LoadScales(scales.data(), scales.size());
if (!loader.ReadAll(pool)) {
HWY_ABORT("Failed to load model weights.");

View File

@@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#include <stddef.h>
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
@@ -221,17 +223,17 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
}
// Calls func(name, float*, CompressedArray&) for each tensor. float* is
// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
// specified and we pass a float* pointing to the raw float weights for that
// tensor for use by compress_weights.cc.
// null if raw_weights is nullptr, e.g., when loading weights from BlobStore.
// Otherwise, RawLayer must be specified and we pass a float* pointing to the
// raw float weights for that tensor for use by compress_weights.cc.
//
// This avoids repeating the list of tensors between loading and compressing,
// while also avoiding dependency on raw_weights.h.
template <bool kHaveRaw, class TConfig, class RawLayer = void,
class RawWeights = void, class Func>
void ForEachTensor(const RawWeights* raw_weights,
template <class TConfig, class RawLayer = void, class RawWeightsPtr, class Func>
void ForEachTensor(RawWeightsPtr raw_weights,
CompressedWeights<TConfig>& c_weights, Func& func) {
constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, nullptr_t>();
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);