From 7d0720675f99f41ca884ef9b2cc1331a5b6a07b7 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 17 Jun 2024 06:16:17 -0700 Subject: [PATCH] Move raw_weights into separate header, used mainly by compress_weights. Fix warnings in backprop/* (include) PiperOrigin-RevId: 643983136 --- BUILD.bazel | 21 ++- CMakeLists.txt | 1 + backprop/backward-inl.h | 1 - backprop/backward.h | 2 - backprop/backward_scalar.h | 3 +- backprop/backward_scalar_test.cc | 12 +- backprop/backward_test.cc | 9 +- backprop/forward-inl.h | 9 +- backprop/forward_scalar.h | 2 +- backprop/test_util.h | 33 +--- gemma/compress_weights.cc | 4 +- gemma/gemma.cc | 5 +- gemma/weights.cc | 4 +- gemma/weights.h | 284 ++++++++----------------------- gemma/weights_raw.h | 242 ++++++++++++++++++++++++++ 15 files changed, 353 insertions(+), 279 deletions(-) create mode 100644 gemma/weights_raw.h diff --git a/BUILD.bazel b/BUILD.bazel index fc64dc5..42bfeea 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -83,6 +83,18 @@ cc_library( ], ) +cc_library( + name = "weights_raw", + hdrs = ["gemma/weights_raw.h"], + deps = [ + ":common", + ":weights", + "//compression:compress", + "@hwy//:hwy", + "@hwy//:thread_pool", + ], +) + cc_library( name = "gemma_lib", srcs = [ @@ -214,6 +226,7 @@ cc_binary( ":common", ":gemma_lib", ":weights", + ":weights_raw", # Placeholder for internal dep, do not remove., "//compression:compress", "@hwy//:hwy", @@ -331,7 +344,7 @@ cc_library( ":common", ":gemma_lib", ":prompt", - ":weights", + ":weights_raw", ], ) @@ -346,7 +359,7 @@ cc_test( ":backprop_scalar", ":prompt", ":sampler", - ":weights", + ":weights_raw", "@googletest//:gtest_main", ], ) @@ -363,11 +376,9 @@ cc_test( ":backprop_scalar", ":gemma_lib", ":ops", - ":prompt", ":sampler", - ":weights", + ":weights_raw", "@googletest//:gtest_main", - "//compression:compress", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:thread_pool", diff --git a/CMakeLists.txt b/CMakeLists.txt index 1473d3b..ae4fb73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,6 +73,7 @@ set(SOURCES gemma/ops.h gemma/weights.cc gemma/weights.h + gemma/weights_raw.h util/app.h util/args.h ) diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index d6c4d68..837dc13 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -23,7 +23,6 @@ #include #include -#include #include #include "backprop/prompt.h" diff --git a/backprop/backward.h b/backprop/backward.h index 6917f20..aac2122 100644 --- a/backprop/backward.h +++ b/backprop/backward.h @@ -16,8 +16,6 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ -#include - #include "backprop/prompt.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index 024e2a4..aa652ac 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -20,14 +20,13 @@ #include #include -#include #include #include "backprop/common_scalar.h" #include "backprop/prompt.h" #include "gemma/activations.h" #include "gemma/common.h" // EmbeddingScaling -#include "gemma/weights.h" +#include "gemma/weights_raw.h" namespace gcpp { template diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index 9a94484..2a9d99b 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -15,6 +15,9 @@ #include "backprop/backward_scalar.h" +#include +#include // memset + #include #include #include @@ -23,6 +26,7 @@ #include "backprop/forward_scalar.h" 
#include "backprop/sampler.h" #include "backprop/test_util.h" +#include "gemma/weights_raw.h" namespace gcpp { @@ -55,8 +59,8 @@ TEST(BackPropTest, MatMulVJP) { memset(&grad, 0, sizeof(grad)); MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), kRows, kCols, kTokens); - TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__); - TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__); + TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__); + TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__); } } @@ -91,8 +95,8 @@ TEST(BackPropTest, MultiHeadMatMulVJP) { memset(&grad, 0, sizeof(grad)); MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), kHeads, kRows, kCols, kTokens); - TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__); - TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__); + TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__); + TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__); } } diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 4a3e4cc..146ce67 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -19,7 +19,6 @@ #include -#include #include #include #include @@ -29,10 +28,8 @@ #include "backprop/forward_scalar.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "compression/compress.h" -#include "gemma/gemma.h" -#include "gemma/weights.h" -#include "hwy/aligned_allocator.h" +#include "gemma/activations.h" +#include "gemma/weights_raw.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -52,8 +49,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; - void TestMatMulVJP() { static const size_t kRows = 8; static const size_t kCols = 64; diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 49efedd..b322061 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -20,8 +20,8 @@ #include #include -#include #include +#include #include "gemma/activations.h" #include "gemma/common.h" @@ -40,11 +40,11 @@ #endif #include "gemma/ops.h" +#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; template void InputEmbedding(const ArrayT& weights, const std::vector& prompt, @@ -202,11 +202,10 @@ void ApplyForwardLayer(const LayerT& weights, activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; - using VF = hn::Vec; DF df; for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { - const auto y = Load(df, out + i); - const auto x = Load(df, out_mul + i); + const auto y = hn::Load(df, out + i); + const auto x = hn::Load(df, out_mul + i); hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); } } diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 8bc125e..95c5f0c 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -27,7 +27,7 @@ #include "backprop/prompt.h" #include "gemma/activations.h" #include "gemma/common.h" // EmbeddingScaling -#include "gemma/weights.h" +#include "gemma/weights_raw.h" namespace gcpp { diff --git a/backprop/test_util.h b/backprop/test_util.h index 939411d..387b979 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -16,43 +16,16 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ +#include + #include #include -#include -#include "gemma/weights.h" #include "gtest/gtest.h" +#include "gemma/weights_raw.h" 
 namespace gcpp {
 
-template <typename T, size_t kLen>
-void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
-  std::normal_distribution<T> dist(0.0, stddev);
-  for (size_t i = 0; i < kLen; ++i) {
-    x[i] = dist(gen);
-  }
-}
-
-template <typename T, typename TConfig>
-void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
-  RandInit(w.pre_attention_norm_scale, stddev, gen);
-  RandInit(w.attn_vec_einsum_w, stddev, gen);
-  RandInit(w.qkv_einsum_w, stddev, gen);
-  RandInit(w.pre_ffw_norm_scale, stddev, gen);
-  RandInit(w.gating_einsum_w, stddev, gen);
-  RandInit(w.linear_w, stddev, gen);
-}
-
-template <typename T, typename TConfig>
-void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
-  static constexpr size_t kLayers = TConfig::kLayers;
-  RandInit(w.embedder_input_embedding, stddev, gen);
-  RandInit(w.final_norm_scale, stddev, gen);
-  for (size_t i = 0; i < kLayers; ++i) {
-    RandInit(*w.GetLayer(i), stddev, gen);
-  }
-}
-
 template <typename T, size_t kLen>
 void Complexify(const std::array<T, kLen>& x,
                 std::array<std::complex<T>, kLen>& c_x) {
diff --git a/gemma/compress_weights.cc b/gemma/compress_weights.cc
index 8552802..bff3460 100644
--- a/gemma/compress_weights.cc
+++ b/gemma/compress_weights.cc
@@ -40,6 +40,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/common.h"    // Model
 #include "gemma/weights.h"
+#include "gemma/weights_raw.h"
 #include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -317,7 +318,8 @@ void CompressWeights(const Path& weights_path,
   WeightsF<TConfig>* weights =
       reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
   Compressor compressor(pool);
-  ForEachTensor(weights, *c_weights, compressor);
+  ForEachTensor>(
+      weights, *c_weights, compressor);
   compressor.AddScales(weights->scales.data(), weights->scales.size());
   compressor.WriteAll(pool, compressed_weights_path);
 
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 1716399..347a4cd 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -66,7 +66,6 @@ constexpr bool kShowTokenization = false;
 // Must be aligned.
 template
 struct Activations {
-  using LayerConfig = LayerF<TConfig>;
   static constexpr size_t kModelDim = TConfig::kModelDim;
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
@@ -979,8 +978,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
 }
 
 Gemma::~Gemma() {
-  CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
-                                          weights_u8_);
+  CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
+                                                 weights_u8_);
 }
 
 void Gemma::Generate(const RuntimeConfig& runtime_config,
diff --git a/gemma/weights.cc b/gemma/weights.cc
index d850cfd..660576a 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -15,7 +15,6 @@
 
 #include "gemma/weights.h"
 
-#include
 #include
 
 #include "compression/compress.h"
@@ -47,7 +46,8 @@ struct LoadCompressedWeightsT {
    std::array<float, TConfig::kNumTensorScales> scales;
 
    CacheLoader loader(weights);
-   ForEachTensor(nullptr, *c_weights, loader);
+   const void* raw_weights = nullptr;  // ForEachTensor requires const.
+   ForEachTensor(raw_weights, *c_weights, loader);
    loader.LoadScales(scales.data(), scales.size());
    if (!loader.ReadAll(pool)) {
      HWY_ABORT("Failed to load model weights.");
diff --git a/gemma/weights.h b/gemma/weights.h
index 8d21cf3..5192319 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -25,12 +25,17 @@
 
 namespace gcpp {
 
-// ----------------------------------------------------------------------------
-// Uncompressed
+template <class TConfig>
+struct CompressedLayer {
+  // No ctor/dtor, allocated via AllocateAligned.
+
+  using Weight = typename TConfig::Weight;
+  // If weights are f32, also f32; otherwise at least bf16.
Useful for ops that + // do not yet support smaller compressed types, or require at least bf16. When + // weights are f32, we also want such tensors to be f32. + using WeightF32OrBF16 = + hwy::If(), float, hwy::bfloat16_t>; -template -struct Layer { - Layer() {} static constexpr size_t kHeads = TConfig::kHeads; static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -49,111 +54,6 @@ struct Layer { static constexpr size_t kGriffinDim = TConfig::kGriffinLayers > 0 ? kModelDim : 0; - union { - struct { - std::array attn_vec_einsum_w; - std::array qkv_einsum_w; - std::array attention_output_biases; - }; - - struct { - std::array linear_x_w; - std::array linear_x_biases; - std::array linear_y_w; - std::array linear_y_biases; - std::array linear_out_w; - std::array linear_out_biases; - std::array conv_w; - std::array conv_biases; - std::array gate_w; - std::array gate_biases; - std::array a; - } griffin; - }; - - std::array gating_einsum_w; - std::array linear_w; - std::array pre_attention_norm_scale; - std::array pre_ffw_norm_scale; - std::array post_attention_norm_scale; - std::array post_ffw_norm_scale; - - std::array ffw_gating_biases; - std::array ffw_output_biases; -}; - -template -using LayerF = Layer; - -// Array instead of single large allocation for parallel mem init. Split out of -// Weights so that only these pointers are initialized. -template -struct LayerPointers { - explicit LayerPointers(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { - this->layers[task] = hwy::AllocateAligned>(1); - }); - } - - using TLayer = Layer; - std::array, TConfig::kLayers> layers; -}; - -template -struct Weights { - // No ctor/dtor, allocated via AllocateAligned. - - std::array - embedder_input_embedding; - - std::array final_norm_scale; - - LayerPointers layer_ptrs; - - std::array scales; - - const Layer* GetLayer(size_t layer) const { - return layer_ptrs.layers[layer].get(); - } - Layer* GetLayer(size_t layer) { - return layer_ptrs.layers[layer].get(); - } -}; - -template -using WeightsF = Weights; - -// ---------------------------------------------------------------------------- -// Compressed - -template -struct CompressedLayer { - // No ctor/dtor, allocated via AllocateAligned. - - using TLayer = gcpp::LayerF; - using Weight = typename TConfig::Weight; - // If weights are f32, also f32; otherwise at least bf16. Useful for ops that - // do not yet support smaller compressed types, or require at least bf16. When - // weights are f32, we also want such tensors to be f32. 
- using WeightF32OrBF16 = - hwy::If(), float, hwy::bfloat16_t>; - - static constexpr size_t kHeads = TLayer::kHeads; - static constexpr size_t kKVHeads = TLayer::kKVHeads; - static constexpr size_t kModelDim = TLayer::kModelDim; - static constexpr size_t kQKVDim = TLayer::kQKVDim; - static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim; - static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize; - static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize; - static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize; - static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth; - static constexpr bool kFFBiases = TLayer::kFFBiases; - static constexpr bool kPostNormScale = TConfig::kPostNormScale; - static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim; - static constexpr size_t kGriffinDim = TLayer::kGriffinDim; - - // Compressed Parameters - template using ArrayT = CompressedArray; @@ -171,7 +71,7 @@ struct CompressedLayer { ArrayT linear_y_biases; ArrayT linear_out_w; ArrayT linear_out_biases; - ArrayT conv_w; + ArrayT conv_w; ArrayT conv_biases; ArrayT gate_w; ArrayT gate_biases; @@ -179,7 +79,7 @@ struct CompressedLayer { } griffin; }; - ArrayT gating_einsum_w; + ArrayT gating_einsum_w; ArrayT linear_w; // We don't yet have an RMSNorm that accepts all Weight. ArrayT pre_attention_norm_scale; @@ -251,25 +151,6 @@ struct CompressedWeights { // ---------------------------------------------------------------------------- // Interface -// TODO: can we use TConfig::Weight instead of T? -template -struct AllocateWeights { - ByteStorageT operator()(hwy::ThreadPool& pool) const { - using TWeights = Weights; - ByteStorageT weights_u8 = AllocateSizeof(); - TWeights* weights = reinterpret_cast(weights_u8.get()); - new (&weights->layer_ptrs) LayerPointers(pool); - return weights_u8; - } -}; - -template -struct AllocateWeightsF { - ByteStorageT operator()(hwy::ThreadPool& pool) const { - return AllocateWeights()(pool); - } -}; - template struct AllocateCompressedWeights { ByteStorageT operator()(hwy::ThreadPool& pool) const { @@ -281,86 +162,26 @@ struct AllocateCompressedWeights { } }; -template -struct ZeroInitWeights { - void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { - Weights& w = - *reinterpret_cast*>(weights.get()); - hwy::ZeroBytes(&w.embedder_input_embedding, - sizeof(w.embedder_input_embedding)); - hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); - for (int i = 0; i < TConfig::kLayers; ++i) { - hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i))); - } - } -}; - -template -struct ZeroInitWeightsF { - void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { - ZeroInitWeights()(weights, pool); - } -}; - template struct ZeroInitCompressedWeights { - void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { - CompressedWeights& w = - *reinterpret_cast*>(weights.get()); - w.ZeroInit(); + void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const { + CompressedWeights& weights = + *reinterpret_cast*>(weights_u8.get()); + weights.ZeroInit(); } }; -template -struct CopyWeights { -void operator()(Weights& dst, - const Weights& src) const { - hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding, - sizeof(src.embedder_input_embedding)); - hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale, - sizeof(src.final_norm_scale)); - for (int i = 0; i < TConfig::kLayers; ++i) { - hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), - sizeof(*dst.GetLayer(i))); - } - 
} -}; +// TODO: also add RandInitCompressedWeights template -struct DeleteLayersPtrs { +struct DeleteCompressedWeights { void operator()(ByteStorageT& weights_u8) const { - auto* weights = - reinterpret_cast*>(weights_u8.get()); - weights->~CompressedWeights(); + CompressedWeights& weights = + *reinterpret_cast*>(weights_u8.get()); + weights.~CompressedWeights(); } }; -// Owns weights and provides access to TConfig. -template -class WeightsWrapper { - public: - WeightsWrapper() - : pool_(0), - data_(AllocateWeights()(pool_)), - weights_(reinterpret_cast*>(data_.get())) {} - - ~WeightsWrapper() { - get().layer_ptrs.~LayerPointers(); - } - - const Weights& get() const { return *weights_; } - Weights& get() { return *weights_; } - void clear() { ZeroInitWeights()(data_, pool_); } - void copy(const WeightsWrapper& other) { - CopyWeights()(get(), other.get()); - } - - private: - hwy::ThreadPool pool_; - ByteStorageT data_; - Weights* weights_; -}; - ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, Type weight_type, hwy::ThreadPool& pool); @@ -369,30 +190,60 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights); // ---------------------------------------------------------------------------- // Iterators -// Calls func(name, float*, CompressedArray&) for each tensor. float* is null -// if weights = null, which happens during the first call where we attempt to -// load from cache. -// -// This avoids repeating the list of tensors between loading and compressing. -template -void ForEachTensor(const WeightsF* weights, - CompressedWeights& c_weights, Func& func) { - func("c_embedding", - weights ? weights->embedder_input_embedding.data() : nullptr, - c_weights.embedder_input_embedding); - func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr, - c_weights.final_norm_scale); +// We rely on `if constexpr` to ensure raw_weights->member is only compiled +// when valid, i.e., kHaveRaw == true, but the IDE analysis does not understand +// this, hence hide the member access from it. +#if HWY_IDE +#define GEMMA_MEMBER(aggregate, member) nullptr +#else +#define GEMMA_MEMBER(aggregate, member) aggregate->member +#endif +// Used by ForEachTensor for tensors that are not in a layer. +#define GEMMA_CALL_TOP_FUNC(name, member) \ + { \ + const float* raw_tensor = nullptr; \ + if constexpr (kHaveRaw) { \ + raw_tensor = GEMMA_MEMBER(raw_weights, member.data()); \ + } \ + func(name, raw_tensor, c_weights.member); \ + } + +// Used by ForEachTensor for per-layer tensors. Writes into name_buf. #define GEMMA_CALL_FUNC(name, member) \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ - func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member) + { \ + const float* raw_tensor = nullptr; \ + if constexpr (kHaveRaw) { \ + raw_tensor = GEMMA_MEMBER(raw_layer, member.data()); \ + } \ + func(name_buf, raw_tensor, c_layer->member); \ + } + +// Calls func(name, float*, CompressedArray&) for each tensor. float* is +// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens +// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be +// specified and we pass a float* pointing to the raw float weights for that +// tensor for use by compress_weights.cc. +// +// This avoids repeating the list of tensors between loading and compressing, +// while also avoiding dependency on raw_weights.h. 
+template
+void ForEachTensor(const RawWeights* raw_weights,
+                   CompressedWeights<TConfig>& c_weights, Func& func) {
+  GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
+  GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
     const size_t idx = static_cast<size_t>(layer_idx);
-    const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
-    CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
+    const RawLayer* raw_layer = nullptr;
+    if constexpr (kHaveRaw) {
+      raw_layer = raw_weights->GetLayer(idx);
+    }
+    CompressedLayer<TConfig>* c_layer = c_weights.GetLayer(idx);
     GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
     GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
@@ -430,6 +281,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
   }
 }
 #undef GEMMA_CALL_FUNC
+#undef GEMMA_CALL_TOP_FUNC
 }  // ForEachTensor
 
 #define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
diff --git a/gemma/weights_raw.h b/gemma/weights_raw.h
new file mode 100644
index 0000000..cb66876
--- /dev/null
+++ b/gemma/weights_raw.h
@@ -0,0 +1,242 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_
+
+// NOTE: this file should only be used by compress_weights; it is currently
+// also referenced by backprop, but we plan to remove that. Historical note:
+// this was the original f32-only simple on-disk format created by a Python
+// export script. BlobStore is now the preferred on-disk format, and we load
+// that into CompressedWeights.
+
+#include
+
+#include "gemma/common.h"
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+
+template <typename T, typename TConfig>
+struct Layer {
+  Layer() {}
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
+  static constexpr size_t kQKVEinsumWSize =
+      (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
+  // 2x for (gelu gating vector, gated vector)
+  static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
+  static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
+  static constexpr bool kFFBiases = TConfig::kFFBiases;
+  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr size_t kAOBiasDim =
+      TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
+  static constexpr size_t kGriffinDim =
+      TConfig::kGriffinLayers > 0 ? kModelDim : 0;
+
+  union {
+    struct {
+      std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
+      std::array<T, kQKVEinsumWSize> qkv_einsum_w;
+      std::array<T, kAOBiasDim> attention_output_biases;
+    };
+
+    struct {
+      std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
+      std::array<T, kGriffinDim> linear_x_biases;
+      std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
+      std::array<T, kGriffinDim> linear_y_biases;
+      std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
+      std::array<T, kGriffinDim> linear_out_biases;
+      std::array<T, kConv1dWidth * kGriffinDim> conv_w;
+      std::array<T, kGriffinDim> conv_biases;
+      std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
+      std::array<T, kGriffinDim * 2> gate_biases;
+      std::array<T, kGriffinDim> a;
+    } griffin;
+  };
+
+  std::array<T, kGatingEinsumWSize> gating_einsum_w;
+  std::array<T, kModelDim * kFFHiddenDim> linear_w;
+  std::array<T, kModelDim> pre_attention_norm_scale;
+  std::array<T, kModelDim> pre_ffw_norm_scale;
+  std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
+  std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+
+  std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
+  std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
+};
+
+template <class TConfig>
+using LayerF = Layer<float, TConfig>;
+
+// Array instead of single large allocation for parallel mem init. Split out of
+// Weights so that only these pointers are initialized.
+template <typename T, typename TConfig>
+struct LayerPointers {
+  explicit LayerPointers(hwy::ThreadPool& pool) {
+    pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
+      this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
+    });
+  }
+
+  using TLayer = Layer<T, TConfig>;
+  std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
+};
+
+template <typename T, typename TConfig>
+struct Weights {
+  // No ctor/dtor, allocated via AllocateAligned.
+
+  std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
+      embedder_input_embedding;
+
+  std::array<T, TConfig::kModelDim> final_norm_scale;
+
+  LayerPointers<T, TConfig> layer_ptrs;
+
+  std::array<float, TConfig::kNumTensorScales> scales;
+
+  const Layer<T, TConfig>* GetLayer(size_t layer) const {
+    return layer_ptrs.layers[layer].get();
+  }
+  Layer<T, TConfig>* GetLayer(size_t layer) {
+    return layer_ptrs.layers[layer].get();
+  }
+};
+
+template <class TConfig>
+using WeightsF = Weights<float, TConfig>;
+
+// TODO: can we use TConfig::Weight instead of T?
+template <typename T, typename TConfig>
+struct AllocateWeights {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    using TWeights = Weights<T, TConfig>;
+    ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
+    TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
+    new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
+    return weights_u8;
+  }
+};
+
+template <class TConfig>
+struct AllocateWeightsF {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    return AllocateWeights<float, TConfig>()(pool);
+  }
+};
+
+// TODO: make a member of Weights.
+template <typename T, typename TConfig>
+struct ZeroInitWeights {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    Weights<T, TConfig>& w =
+        *reinterpret_cast<Weights<T, TConfig>*>(weights.get());
+    hwy::ZeroBytes(&w.embedder_input_embedding,
+                   sizeof(w.embedder_input_embedding));
+    hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
+    for (int i = 0; i < TConfig::kLayers; ++i) {
+      hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
+    }
+  }
+};
+
+template <class TConfig>
+struct ZeroInitWeightsF {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    ZeroInitWeights<float, TConfig>()(weights, pool);
+  }
+};
+
+template <typename T, typename TConfig>
+struct CopyWeights {
+void operator()(Weights<T, TConfig>& dst,
+                const Weights<T, TConfig>& src) const {
+  hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
+                 sizeof(src.embedder_input_embedding));
+  hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
+                 sizeof(src.final_norm_scale));
+  for (int i = 0; i < TConfig::kLayers; ++i) {
+    hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i),
+                   sizeof(*dst.GetLayer(i)));
+  }
+ }
+};
+
+template <typename T, size_t kLen>
+void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
+  std::normal_distribution<T> dist(0.0, stddev);
+  for (size_t i = 0; i < kLen; ++i) {
+    x[i] = dist(gen);
+  }
+}
+
+// TODO: make a member of Layer.
+template <typename T, typename TConfig>
+void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
+  RandInit(w.pre_attention_norm_scale, stddev, gen);
+  RandInit(w.attn_vec_einsum_w, stddev, gen);
+  RandInit(w.qkv_einsum_w, stddev, gen);
+  RandInit(w.pre_ffw_norm_scale, stddev, gen);
+  RandInit(w.gating_einsum_w, stddev, gen);
+  RandInit(w.linear_w, stddev, gen);
+}
+
+template <typename T, typename TConfig>
+void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
+  static constexpr size_t kLayers = TConfig::kLayers;
+  RandInit(w.embedder_input_embedding, stddev, gen);
+  RandInit(w.final_norm_scale, stddev, gen);
+  for (size_t i = 0; i < kLayers; ++i) {
+    RandInit(*w.GetLayer(i), stddev, gen);
+  }
+}
+
+// Owns weights and provides access to TConfig.
+template <typename T, typename TConfig>
+class WeightsWrapper {
+ public:
+  WeightsWrapper()
+      : pool_(0),
+        data_(AllocateWeights<T, TConfig>()(pool_)),
+        weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
+
+  ~WeightsWrapper() {
+    get().layer_ptrs.~LayerPointers();
+  }
+
+  const Weights<T, TConfig>& get() const { return *weights_; }
+  Weights<T, TConfig>& get() { return *weights_; }
+  void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
+  void copy(const WeightsWrapper& other) {
+    CopyWeights<T, TConfig>()(get(), other.get());
+  }
+
+ private:
+  hwy::ThreadPool pool_;
+  ByteStorageT data_;
+  Weights<T, TConfig>* weights_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_