diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc
index bff3460..aaaa6ab 100644
--- a/gemma/compress_weights.cc
+++ b/gemma/compress_weights.cc
@@ -318,8 +318,7 @@ void CompressWeights(const Path& weights_path,
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
-  ForEachTensor<true, TConfig, WeightsF<TConfig>>(
-      weights, *c_weights, compressor);
+  ForEachTensor<TConfig>(weights, *c_weights, compressor);
   compressor.AddScales(weights->scales.data(), weights->scales.size());
   compressor.WriteAll(pool, compressed_weights_path);
diff --git a/gemma/weights.cc b/gemma/weights.cc
index 660576a..aeb6cab 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -46,8 +46,7 @@ struct LoadCompressedWeightsT {
   std::array<float, TConfig::kNumTensorScales> scales;
   CacheLoader loader(weights);
-  const void* raw_weights = nullptr;  // ForEachTensor requires const.
-  ForEachTensor<false, TConfig, void>(raw_weights, *c_weights, loader);
+  ForEachTensor<TConfig>(nullptr, *c_weights, loader);
   loader.LoadScales(scales.data(), scales.size());
   if (!loader.ReadAll(pool)) {
     HWY_ABORT("Failed to load model weights.");
diff --git a/gemma/weights.h b/gemma/weights.h
index 5192319..d69f61a 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -16,6 +16,8 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
 
+#include <cstddef>
+
 #include "compression/compress.h"
 #include "gemma/common.h"
 #include "gemma/configs.h"
@@ -221,17 +223,17 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is
-// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
-// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
-// specified and we pass a float* pointing to the raw float weights for that
-// tensor for use by compress_weights.cc.
+// null if raw_weights is nullptr, e.g., when loading weights from BlobStore.
+// Otherwise, RawLayer must be specified and we pass a float* pointing to the
+// raw float weights for that tensor for use by compress_weights.cc.
 //
 // This avoids repeating the list of tensors between loading and compressing,
 // while also avoiding dependency on raw_weights.h.
-template <bool kHaveRaw, typename TConfig, typename RawWeights, class Func>
-void ForEachTensor(const RawWeights* raw_weights,
+template <typename TConfig, typename RawWeightsPtr, class Func>
+void ForEachTensor(RawWeightsPtr raw_weights,
                    CompressedWeights<TConfig>& c_weights, Func& func) {
+  constexpr bool kHaveRaw = !hwy::IsSame<RawWeightsPtr, std::nullptr_t>();
+
   GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);