mirror of https://github.com/google/gemma.cpp.git
Move raw_weights into separate header, used mainly by compress_weights.
Fix warnings in backprop/* (include) PiperOrigin-RevId: 643983136
This commit is contained in:
parent
ad790d89d1
commit
7d0720675f
21
BUILD.bazel
21
BUILD.bazel
|
|
@ -83,6 +83,18 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "weights_raw",
|
||||||
|
hdrs = ["gemma/weights_raw.h"],
|
||||||
|
deps = [
|
||||||
|
":common",
|
||||||
|
":weights",
|
||||||
|
"//compression:compress",
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:thread_pool",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gemma_lib",
|
name = "gemma_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
@ -214,6 +226,7 @@ cc_binary(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":weights",
|
":weights",
|
||||||
|
":weights_raw",
|
||||||
# Placeholder for internal dep, do not remove.,
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
|
@ -331,7 +344,7 @@ cc_library(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":prompt",
|
":prompt",
|
||||||
":weights",
|
":weights_raw",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -346,7 +359,7 @@ cc_test(
|
||||||
":backprop_scalar",
|
":backprop_scalar",
|
||||||
":prompt",
|
":prompt",
|
||||||
":sampler",
|
":sampler",
|
||||||
":weights",
|
":weights_raw",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -363,11 +376,9 @@ cc_test(
|
||||||
":backprop_scalar",
|
":backprop_scalar",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
":prompt",
|
|
||||||
":sampler",
|
":sampler",
|
||||||
":weights",
|
":weights_raw",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:compress",
|
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,7 @@ set(SOURCES
|
||||||
gemma/ops.h
|
gemma/ops.h
|
||||||
gemma/weights.cc
|
gemma/weights.cc
|
||||||
gemma/weights.h
|
gemma/weights.h
|
||||||
|
gemma/weights_raw.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,6 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
|
||||||
|
|
@ -20,14 +20,13 @@
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <complex>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "backprop/common_scalar.h"
|
#include "backprop/common_scalar.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h" // EmbeddingScaling
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,9 @@
|
||||||
|
|
||||||
#include "backprop/backward_scalar.h"
|
#include "backprop/backward_scalar.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <string.h> // memset
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
@ -23,6 +26,7 @@
|
||||||
#include "backprop/forward_scalar.h"
|
#include "backprop/forward_scalar.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
@ -29,10 +28,8 @@
|
||||||
#include "backprop/forward_scalar.h"
|
#include "backprop/forward_scalar.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
#include "compression/compress.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/weights_raw.h"
|
||||||
#include "gemma/weights.h"
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
|
@ -52,8 +49,6 @@ HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
void TestMatMulVJP() {
|
void TestMatMulVJP() {
|
||||||
static const size_t kRows = 8;
|
static const size_t kRows = 8;
|
||||||
static const size_t kCols = 64;
|
static const size_t kCols = 64;
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,8 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
|
|
@ -40,11 +40,11 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "gemma/ops.h"
|
#include "gemma/ops.h"
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
template <typename ArrayT>
|
template <typename ArrayT>
|
||||||
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
||||||
|
|
@ -202,11 +202,10 @@ void ApplyForwardLayer(const LayerT<TConfig>& weights,
|
||||||
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
|
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using DF = hn::ScalableTag<float>;
|
using DF = hn::ScalableTag<float>;
|
||||||
using VF = hn::Vec<DF>;
|
|
||||||
DF df;
|
DF df;
|
||||||
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
|
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
|
||||||
const auto y = Load(df, out + i);
|
const auto y = hn::Load(df, out + i);
|
||||||
const auto x = Load(df, out_mul + i);
|
const auto x = hn::Load(df, out_mul + i);
|
||||||
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
|
hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h" // EmbeddingScaling
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,43 +16,16 @@
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <random>
|
|
||||||
|
|
||||||
#include "gemma/weights.h"
|
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
#include "gemma/weights_raw.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
template<typename T, size_t kLen>
|
|
||||||
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
|
|
||||||
std::normal_distribution<T> dist(0.0, stddev);
|
|
||||||
for (size_t i = 0; i < kLen; ++i) {
|
|
||||||
x[i] = dist(gen);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T, typename TConfig>
|
|
||||||
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
|
|
||||||
RandInit(w.pre_attention_norm_scale, stddev, gen);
|
|
||||||
RandInit(w.attn_vec_einsum_w, stddev, gen);
|
|
||||||
RandInit(w.qkv_einsum_w, stddev, gen);
|
|
||||||
RandInit(w.pre_ffw_norm_scale, stddev, gen);
|
|
||||||
RandInit(w.gating_einsum_w, stddev, gen);
|
|
||||||
RandInit(w.linear_w, stddev, gen);
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T, typename TConfig>
|
|
||||||
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
|
|
||||||
static constexpr size_t kLayers = TConfig::kLayers;
|
|
||||||
RandInit(w.embedder_input_embedding, stddev, gen);
|
|
||||||
RandInit(w.final_norm_scale, stddev, gen);
|
|
||||||
for (size_t i = 0; i < kLayers; ++i) {
|
|
||||||
RandInit(*w.GetLayer(i), stddev, gen);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T, typename U, size_t kLen>
|
template<typename T, typename U, size_t kLen>
|
||||||
void Complexify(const std::array<T, kLen>& x,
|
void Complexify(const std::array<T, kLen>& x,
|
||||||
std::array<std::complex<U>, kLen>& c_x) {
|
std::array<std::complex<U>, kLen>& c_x) {
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "gemma/common.h" // Model
|
#include "gemma/common.h" // Model
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
#include "gemma/weights_raw.h"
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -317,7 +318,8 @@ void CompressWeights(const Path& weights_path,
|
||||||
WeightsF<TConfig>* weights =
|
WeightsF<TConfig>* weights =
|
||||||
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||||
Compressor compressor(pool);
|
Compressor compressor(pool);
|
||||||
ForEachTensor<TConfig>(weights, *c_weights, compressor);
|
ForEachTensor</*kHaveRaw=*/true, TConfig, LayerF<TConfig>>(
|
||||||
|
weights, *c_weights, compressor);
|
||||||
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
||||||
compressor.WriteAll(pool, compressed_weights_path);
|
compressor.WriteAll(pool, compressed_weights_path);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,7 +66,6 @@ constexpr bool kShowTokenization = false;
|
||||||
// Must be aligned.
|
// Must be aligned.
|
||||||
template <class TConfig, size_t kBatchSize>
|
template <class TConfig, size_t kBatchSize>
|
||||||
struct Activations {
|
struct Activations {
|
||||||
using LayerConfig = LayerF<TConfig>;
|
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
|
@ -979,7 +978,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
Gemma::~Gemma() {
|
Gemma::~Gemma() {
|
||||||
CallForModelAndWeight<DeleteLayersPtrs>(model_type_, weight_type_,
|
CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
|
||||||
weights_u8_);
|
weights_u8_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
|
@ -47,7 +46,8 @@ struct LoadCompressedWeightsT {
|
||||||
|
|
||||||
std::array<float, TConfig::kNumTensorScales> scales;
|
std::array<float, TConfig::kNumTensorScales> scales;
|
||||||
CacheLoader loader(weights);
|
CacheLoader loader(weights);
|
||||||
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
|
const void* raw_weights = nullptr; // ForEachTensor requires const.
|
||||||
|
ForEachTensor</*kHaveRaw=*/false, TConfig>(raw_weights, *c_weights, loader);
|
||||||
loader.LoadScales(scales.data(), scales.size());
|
loader.LoadScales(scales.data(), scales.size());
|
||||||
if (!loader.ReadAll(pool)) {
|
if (!loader.ReadAll(pool)) {
|
||||||
HWY_ABORT("Failed to load model weights.");
|
HWY_ABORT("Failed to load model weights.");
|
||||||
|
|
|
||||||
284
gemma/weights.h
284
gemma/weights.h
|
|
@ -25,12 +25,17 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
template <class TConfig>
|
||||||
// Uncompressed
|
struct CompressedLayer {
|
||||||
|
// No ctor/dtor, allocated via AllocateAligned.
|
||||||
|
|
||||||
|
using Weight = typename TConfig::Weight;
|
||||||
|
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
|
||||||
|
// do not yet support smaller compressed types, or require at least bf16. When
|
||||||
|
// weights are f32, we also want such tensors to be f32.
|
||||||
|
using WeightF32OrBF16 =
|
||||||
|
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
|
||||||
|
|
||||||
template <typename T, class TConfig>
|
|
||||||
struct Layer {
|
|
||||||
Layer() {}
|
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
|
@ -49,111 +54,6 @@ struct Layer {
|
||||||
static constexpr size_t kGriffinDim =
|
static constexpr size_t kGriffinDim =
|
||||||
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
||||||
|
|
||||||
union {
|
|
||||||
struct {
|
|
||||||
std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
|
|
||||||
std::array<T, kQKVEinsumWSize> qkv_einsum_w;
|
|
||||||
std::array<T, kAOBiasDim> attention_output_biases;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct {
|
|
||||||
std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
|
|
||||||
std::array<T, kGriffinDim> linear_x_biases;
|
|
||||||
std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
|
|
||||||
std::array<T, kGriffinDim> linear_y_biases;
|
|
||||||
std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
|
|
||||||
std::array<T, kGriffinDim> linear_out_biases;
|
|
||||||
std::array<T, kConv1dWidth * kGriffinDim> conv_w;
|
|
||||||
std::array<T, kGriffinDim> conv_biases;
|
|
||||||
std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
|
||||||
std::array<T, kGriffinDim * 2> gate_biases;
|
|
||||||
std::array<T, kGriffinDim> a;
|
|
||||||
} griffin;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::array<T, kGatingEinsumWSize> gating_einsum_w;
|
|
||||||
std::array<T, kModelDim * kFFHiddenDim> linear_w;
|
|
||||||
std::array<T, kModelDim> pre_attention_norm_scale;
|
|
||||||
std::array<T, kModelDim> pre_ffw_norm_scale;
|
|
||||||
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
|
|
||||||
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
|
|
||||||
|
|
||||||
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
|
||||||
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class TConfig>
|
|
||||||
using LayerF = Layer<float, TConfig>;
|
|
||||||
|
|
||||||
// Array instead of single large allocation for parallel mem init. Split out of
|
|
||||||
// Weights so that only these pointers are initialized.
|
|
||||||
template <typename T, class TConfig>
|
|
||||||
struct LayerPointers {
|
|
||||||
explicit LayerPointers(hwy::ThreadPool& pool) {
|
|
||||||
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
|
|
||||||
this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
using TLayer = Layer<T, TConfig>;
|
|
||||||
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, class TConfig>
|
|
||||||
struct Weights {
|
|
||||||
// No ctor/dtor, allocated via AllocateAligned.
|
|
||||||
|
|
||||||
std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
|
|
||||||
embedder_input_embedding;
|
|
||||||
|
|
||||||
std::array<T, TConfig::kModelDim> final_norm_scale;
|
|
||||||
|
|
||||||
LayerPointers<T, TConfig> layer_ptrs;
|
|
||||||
|
|
||||||
std::array<T, TConfig::kNumTensorScales> scales;
|
|
||||||
|
|
||||||
const Layer<T, TConfig>* GetLayer(size_t layer) const {
|
|
||||||
return layer_ptrs.layers[layer].get();
|
|
||||||
}
|
|
||||||
Layer<T, TConfig>* GetLayer(size_t layer) {
|
|
||||||
return layer_ptrs.layers[layer].get();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class TConfig>
|
|
||||||
using WeightsF = Weights<float, TConfig>;
|
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
|
||||||
// Compressed
|
|
||||||
|
|
||||||
template <class TConfig>
|
|
||||||
struct CompressedLayer {
|
|
||||||
// No ctor/dtor, allocated via AllocateAligned.
|
|
||||||
|
|
||||||
using TLayer = gcpp::LayerF<TConfig>;
|
|
||||||
using Weight = typename TConfig::Weight;
|
|
||||||
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
|
|
||||||
// do not yet support smaller compressed types, or require at least bf16. When
|
|
||||||
// weights are f32, we also want such tensors to be f32.
|
|
||||||
using WeightF32OrBF16 =
|
|
||||||
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
|
|
||||||
|
|
||||||
static constexpr size_t kHeads = TLayer::kHeads;
|
|
||||||
static constexpr size_t kKVHeads = TLayer::kKVHeads;
|
|
||||||
static constexpr size_t kModelDim = TLayer::kModelDim;
|
|
||||||
static constexpr size_t kQKVDim = TLayer::kQKVDim;
|
|
||||||
static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
|
|
||||||
static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
|
|
||||||
static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
|
|
||||||
static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
|
|
||||||
static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
|
|
||||||
static constexpr bool kFFBiases = TLayer::kFFBiases;
|
|
||||||
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
|
|
||||||
static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
|
|
||||||
static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
|
|
||||||
|
|
||||||
// Compressed Parameters
|
|
||||||
|
|
||||||
template <class T, size_t N>
|
template <class T, size_t N>
|
||||||
using ArrayT = CompressedArray<T, N>;
|
using ArrayT = CompressedArray<T, N>;
|
||||||
|
|
||||||
|
|
@ -171,7 +71,7 @@ struct CompressedLayer {
|
||||||
ArrayT<float, kGriffinDim> linear_y_biases;
|
ArrayT<float, kGriffinDim> linear_y_biases;
|
||||||
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim> linear_out_w;
|
||||||
ArrayT<float, kGriffinDim> linear_out_biases;
|
ArrayT<float, kGriffinDim> linear_out_biases;
|
||||||
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
|
ArrayT<float, kConv1dWidth * kGriffinDim> conv_w;
|
||||||
ArrayT<float, kGriffinDim> conv_biases;
|
ArrayT<float, kGriffinDim> conv_biases;
|
||||||
ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
ArrayT<Weight, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||||
|
|
@ -179,7 +79,7 @@ struct CompressedLayer {
|
||||||
} griffin;
|
} griffin;
|
||||||
};
|
};
|
||||||
|
|
||||||
ArrayT<Weight, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
ArrayT<Weight, kGatingEinsumWSize> gating_einsum_w;
|
||||||
ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
|
ArrayT<Weight, kModelDim * kFFHiddenDim> linear_w;
|
||||||
// We don't yet have an RMSNorm that accepts all Weight.
|
// We don't yet have an RMSNorm that accepts all Weight.
|
||||||
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
|
ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
|
||||||
|
|
@ -251,25 +151,6 @@ struct CompressedWeights {
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Interface
|
// Interface
|
||||||
|
|
||||||
// TODO: can we use TConfig::Weight instead of T?
|
|
||||||
template <typename T, typename TConfig>
|
|
||||||
struct AllocateWeights {
|
|
||||||
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
|
||||||
using TWeights = Weights<T, TConfig>;
|
|
||||||
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
|
|
||||||
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
|
|
||||||
new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
|
|
||||||
return weights_u8;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TConfig>
|
|
||||||
struct AllocateWeightsF {
|
|
||||||
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
|
||||||
return AllocateWeights<float, TConfig>()(pool);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
struct AllocateCompressedWeights {
|
struct AllocateCompressedWeights {
|
||||||
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
||||||
|
|
@ -281,86 +162,26 @@ struct AllocateCompressedWeights {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename TConfig>
|
|
||||||
struct ZeroInitWeights {
|
|
||||||
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
|
||||||
Weights<T, TConfig>& w =
|
|
||||||
*reinterpret_cast<Weights<T, TConfig>*>(weights.get());
|
|
||||||
hwy::ZeroBytes(&w.embedder_input_embedding,
|
|
||||||
sizeof(w.embedder_input_embedding));
|
|
||||||
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
|
|
||||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
|
||||||
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TConfig>
|
|
||||||
struct ZeroInitWeightsF {
|
|
||||||
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
|
||||||
ZeroInitWeights<float, TConfig>()(weights, pool);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
struct ZeroInitCompressedWeights {
|
struct ZeroInitCompressedWeights {
|
||||||
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
|
||||||
CompressedWeights<TConfig>& w =
|
CompressedWeights<TConfig>& weights =
|
||||||
*reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
|
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||||
w.ZeroInit();
|
weights.ZeroInit();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename TConfig>
|
// TODO: also add RandInitCompressedWeights
|
||||||
struct CopyWeights {
|
|
||||||
void operator()(Weights<T, TConfig>& dst,
|
|
||||||
const Weights<T, TConfig>& src) const {
|
|
||||||
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
|
|
||||||
sizeof(src.embedder_input_embedding));
|
|
||||||
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
|
|
||||||
sizeof(src.final_norm_scale));
|
|
||||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
|
||||||
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i),
|
|
||||||
sizeof(*dst.GetLayer(i)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
struct DeleteLayersPtrs {
|
struct DeleteCompressedWeights {
|
||||||
void operator()(ByteStorageT& weights_u8) const {
|
void operator()(ByteStorageT& weights_u8) const {
|
||||||
auto* weights =
|
CompressedWeights<TConfig>& weights =
|
||||||
reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||||
weights->~CompressedWeights<TConfig>();
|
weights.~CompressedWeights<TConfig>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Owns weights and provides access to TConfig.
|
|
||||||
template <typename T, typename TConfig>
|
|
||||||
class WeightsWrapper {
|
|
||||||
public:
|
|
||||||
WeightsWrapper()
|
|
||||||
: pool_(0),
|
|
||||||
data_(AllocateWeights<T, TConfig>()(pool_)),
|
|
||||||
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
|
|
||||||
|
|
||||||
~WeightsWrapper() {
|
|
||||||
get().layer_ptrs.~LayerPointers<T, TConfig>();
|
|
||||||
}
|
|
||||||
|
|
||||||
const Weights<T, TConfig>& get() const { return *weights_; }
|
|
||||||
Weights<T, TConfig>& get() { return *weights_; }
|
|
||||||
void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
|
|
||||||
void copy(const WeightsWrapper<T, TConfig>& other) {
|
|
||||||
CopyWeights<T, TConfig>()(get(), other.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
hwy::ThreadPool pool_;
|
|
||||||
ByteStorageT data_;
|
|
||||||
Weights<T, TConfig>* weights_;
|
|
||||||
};
|
|
||||||
|
|
||||||
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
|
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
|
||||||
Type weight_type, hwy::ThreadPool& pool);
|
Type weight_type, hwy::ThreadPool& pool);
|
||||||
|
|
||||||
|
|
@ -369,30 +190,60 @@ void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Iterators
|
// Iterators
|
||||||
|
|
||||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
// We rely on `if constexpr` to ensure raw_weights->member is only compiled
|
||||||
// if weights = null, which happens during the first call where we attempt to
|
// when valid, i.e., kHaveRaw == true, but the IDE analysis does not understand
|
||||||
// load from cache.
|
// this, hence hide the member access from it.
|
||||||
//
|
#if HWY_IDE
|
||||||
// This avoids repeating the list of tensors between loading and compressing.
|
#define GEMMA_MEMBER(aggregate, member) nullptr
|
||||||
template <class TConfig, class Func>
|
#else
|
||||||
void ForEachTensor(const WeightsF<TConfig>* weights,
|
#define GEMMA_MEMBER(aggregate, member) aggregate->member
|
||||||
CompressedWeights<TConfig>& c_weights, Func& func) {
|
#endif
|
||||||
func("c_embedding",
|
|
||||||
weights ? weights->embedder_input_embedding.data() : nullptr,
|
|
||||||
c_weights.embedder_input_embedding);
|
|
||||||
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
|
|
||||||
c_weights.final_norm_scale);
|
|
||||||
|
|
||||||
|
// Used by ForEachTensor for tensors that are not in a layer.
|
||||||
|
#define GEMMA_CALL_TOP_FUNC(name, member) \
|
||||||
|
{ \
|
||||||
|
const float* raw_tensor = nullptr; \
|
||||||
|
if constexpr (kHaveRaw) { \
|
||||||
|
raw_tensor = GEMMA_MEMBER(raw_weights, member.data()); \
|
||||||
|
} \
|
||||||
|
func(name, raw_tensor, c_weights.member); \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used by ForEachTensor for per-layer tensors. Writes into name_buf.
|
||||||
#define GEMMA_CALL_FUNC(name, member) \
|
#define GEMMA_CALL_FUNC(name, member) \
|
||||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||||
func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
|
{ \
|
||||||
|
const float* raw_tensor = nullptr; \
|
||||||
|
if constexpr (kHaveRaw) { \
|
||||||
|
raw_tensor = GEMMA_MEMBER(raw_layer, member.data()); \
|
||||||
|
} \
|
||||||
|
func(name_buf, raw_tensor, c_layer->member); \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calls func(name, float*, CompressedArray&) for each tensor. float* is
|
||||||
|
// null if !kHaveRaw, in which case raw_weights can be nullptr. This happens
|
||||||
|
// when loading weights from BlobStore. If kHaveRaw, then RawLayer must be
|
||||||
|
// specified and we pass a float* pointing to the raw float weights for that
|
||||||
|
// tensor for use by compress_weights.cc.
|
||||||
|
//
|
||||||
|
// This avoids repeating the list of tensors between loading and compressing,
|
||||||
|
// while also avoiding dependency on raw_weights.h.
|
||||||
|
template <bool kHaveRaw, class TConfig, class RawLayer = void,
|
||||||
|
class RawWeights = void, class Func>
|
||||||
|
void ForEachTensor(const RawWeights* raw_weights,
|
||||||
|
CompressedWeights<TConfig>& c_weights, Func& func) {
|
||||||
|
GEMMA_CALL_TOP_FUNC("c_embedding", embedder_input_embedding);
|
||||||
|
GEMMA_CALL_TOP_FUNC("c_final_norm", final_norm_scale);
|
||||||
|
|
||||||
char name_buf[16];
|
char name_buf[16];
|
||||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||||
auto type = TConfig::kLayerConfig[layer_idx];
|
auto type = TConfig::kLayerConfig[layer_idx];
|
||||||
const size_t idx = static_cast<size_t>(layer_idx);
|
const size_t idx = static_cast<size_t>(layer_idx);
|
||||||
const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
|
const RawLayer* raw_layer = nullptr;
|
||||||
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
|
if constexpr (kHaveRaw) {
|
||||||
|
raw_layer = raw_weights->GetLayer(idx);
|
||||||
|
}
|
||||||
|
CompressedLayer<TConfig>* c_layer = c_weights.GetLayer(idx);
|
||||||
|
|
||||||
GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
|
GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
|
||||||
GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
|
GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
|
||||||
|
|
@ -430,6 +281,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#undef GEMMA_CALL_FUNC
|
#undef GEMMA_CALL_FUNC
|
||||||
|
#undef GEMMA_CALL_TOP_FUNC
|
||||||
} // ForEachTensor
|
} // ForEachTensor
|
||||||
|
|
||||||
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
|
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_
|
||||||
|
|
||||||
|
// NOTE: this file should only be used by compress_weights; it is currently
|
||||||
|
// also referenced by backprop, but we plan to remove that. Historical note:
|
||||||
|
// this was the original f32-only simple on-disk format created by a Python
|
||||||
|
// export script. BlobStore is now the preferred on-disk format, and we load
|
||||||
|
// that into CompressedWeights.
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
|
||||||
|
#include "gemma/common.h"
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
template <typename T, class TConfig>
|
||||||
|
struct Layer {
|
||||||
|
Layer() {}
|
||||||
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
|
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||||
|
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||||
|
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||||
|
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
|
||||||
|
static constexpr size_t kQKVEinsumWSize =
|
||||||
|
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
|
||||||
|
// 2x for (gelu gating vector, gated vector)
|
||||||
|
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
|
||||||
|
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
|
||||||
|
static constexpr bool kFFBiases = TConfig::kFFBiases;
|
||||||
|
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
|
||||||
|
static constexpr size_t kAOBiasDim =
|
||||||
|
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
|
||||||
|
static constexpr size_t kGriffinDim =
|
||||||
|
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
|
||||||
|
|
||||||
|
union {
|
||||||
|
struct {
|
||||||
|
std::array<T, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||||
|
std::array<T, kQKVEinsumWSize> qkv_einsum_w;
|
||||||
|
std::array<T, kAOBiasDim> attention_output_biases;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct {
|
||||||
|
std::array<T, kGriffinDim * kGriffinDim> linear_x_w;
|
||||||
|
std::array<T, kGriffinDim> linear_x_biases;
|
||||||
|
std::array<T, kGriffinDim * kGriffinDim> linear_y_w;
|
||||||
|
std::array<T, kGriffinDim> linear_y_biases;
|
||||||
|
std::array<T, kGriffinDim * kGriffinDim> linear_out_w;
|
||||||
|
std::array<T, kGriffinDim> linear_out_biases;
|
||||||
|
std::array<T, kConv1dWidth * kGriffinDim> conv_w;
|
||||||
|
std::array<T, kGriffinDim> conv_biases;
|
||||||
|
std::array<T, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||||
|
std::array<T, kGriffinDim * 2> gate_biases;
|
||||||
|
std::array<T, kGriffinDim> a;
|
||||||
|
} griffin;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::array<T, kGatingEinsumWSize> gating_einsum_w;
|
||||||
|
std::array<T, kModelDim * kFFHiddenDim> linear_w;
|
||||||
|
std::array<T, kModelDim> pre_attention_norm_scale;
|
||||||
|
std::array<T, kModelDim> pre_ffw_norm_scale;
|
||||||
|
std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
|
||||||
|
std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
|
||||||
|
|
||||||
|
std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||||
|
std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
using LayerF = Layer<float, TConfig>;
|
||||||
|
|
||||||
|
// Array instead of single large allocation for parallel mem init. Split out of
|
||||||
|
// Weights so that only these pointers are initialized.
|
||||||
|
template <typename T, class TConfig>
|
||||||
|
struct LayerPointers {
|
||||||
|
explicit LayerPointers(hwy::ThreadPool& pool) {
|
||||||
|
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
|
||||||
|
this->layers[task] = hwy::AllocateAligned<Layer<T, TConfig>>(1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
using TLayer = Layer<T, TConfig>;
|
||||||
|
std::array<hwy::AlignedFreeUniquePtr<TLayer[]>, TConfig::kLayers> layers;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, class TConfig>
|
||||||
|
struct Weights {
|
||||||
|
// No ctor/dtor, allocated via AllocateAligned.
|
||||||
|
|
||||||
|
std::array<T, TConfig::kVocabSize * TConfig::kModelDim>
|
||||||
|
embedder_input_embedding;
|
||||||
|
|
||||||
|
std::array<T, TConfig::kModelDim> final_norm_scale;
|
||||||
|
|
||||||
|
LayerPointers<T, TConfig> layer_ptrs;
|
||||||
|
|
||||||
|
std::array<T, TConfig::kNumTensorScales> scales;
|
||||||
|
|
||||||
|
const Layer<T, TConfig>* GetLayer(size_t layer) const {
|
||||||
|
return layer_ptrs.layers[layer].get();
|
||||||
|
}
|
||||||
|
Layer<T, TConfig>* GetLayer(size_t layer) {
|
||||||
|
return layer_ptrs.layers[layer].get();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class TConfig>
|
||||||
|
using WeightsF = Weights<float, TConfig>;
|
||||||
|
|
||||||
|
// TODO: can we use TConfig::Weight instead of T?
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
struct AllocateWeights {
|
||||||
|
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
||||||
|
using TWeights = Weights<T, TConfig>;
|
||||||
|
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
|
||||||
|
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
|
||||||
|
new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
|
||||||
|
return weights_u8;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename TConfig>
|
||||||
|
struct AllocateWeightsF {
|
||||||
|
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
||||||
|
return AllocateWeights<float, TConfig>()(pool);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: make a member of Weights<T>.
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
struct ZeroInitWeights {
|
||||||
|
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
||||||
|
Weights<T, TConfig>& w =
|
||||||
|
*reinterpret_cast<Weights<T, TConfig>*>(weights.get());
|
||||||
|
hwy::ZeroBytes(&w.embedder_input_embedding,
|
||||||
|
sizeof(w.embedder_input_embedding));
|
||||||
|
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
|
||||||
|
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||||
|
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename TConfig>
|
||||||
|
struct ZeroInitWeightsF {
|
||||||
|
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
||||||
|
ZeroInitWeights<float, TConfig>()(weights, pool);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
struct CopyWeights {
|
||||||
|
void operator()(Weights<T, TConfig>& dst,
|
||||||
|
const Weights<T, TConfig>& src) const {
|
||||||
|
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
|
||||||
|
sizeof(src.embedder_input_embedding));
|
||||||
|
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
|
||||||
|
sizeof(src.final_norm_scale));
|
||||||
|
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||||
|
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i),
|
||||||
|
sizeof(*dst.GetLayer(i)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, size_t kLen>
|
||||||
|
void RandInit(std::array<T, kLen>& x, T stddev, std::mt19937& gen) {
|
||||||
|
std::normal_distribution<T> dist(0.0, stddev);
|
||||||
|
for (size_t i = 0; i < kLen; ++i) {
|
||||||
|
x[i] = dist(gen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: make a member of Layer<T>.
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
void RandInit(Layer<T, TConfig>& w, T stddev, std::mt19937& gen) {
|
||||||
|
RandInit(w.pre_attention_norm_scale, stddev, gen);
|
||||||
|
RandInit(w.attn_vec_einsum_w, stddev, gen);
|
||||||
|
RandInit(w.qkv_einsum_w, stddev, gen);
|
||||||
|
RandInit(w.pre_ffw_norm_scale, stddev, gen);
|
||||||
|
RandInit(w.gating_einsum_w, stddev, gen);
|
||||||
|
RandInit(w.linear_w, stddev, gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
void RandInit(Weights<T, TConfig>& w, T stddev, std::mt19937& gen) {
|
||||||
|
static constexpr size_t kLayers = TConfig::kLayers;
|
||||||
|
RandInit(w.embedder_input_embedding, stddev, gen);
|
||||||
|
RandInit(w.final_norm_scale, stddev, gen);
|
||||||
|
for (size_t i = 0; i < kLayers; ++i) {
|
||||||
|
RandInit(*w.GetLayer(i), stddev, gen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Owns weights and provides access to TConfig.
|
||||||
|
template <typename T, typename TConfig>
|
||||||
|
class WeightsWrapper {
|
||||||
|
public:
|
||||||
|
WeightsWrapper()
|
||||||
|
: pool_(0),
|
||||||
|
data_(AllocateWeights<T, TConfig>()(pool_)),
|
||||||
|
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
|
||||||
|
|
||||||
|
~WeightsWrapper() {
|
||||||
|
get().layer_ptrs.~LayerPointers<T, TConfig>();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Weights<T, TConfig>& get() const { return *weights_; }
|
||||||
|
Weights<T, TConfig>& get() { return *weights_; }
|
||||||
|
void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
|
||||||
|
void copy(const WeightsWrapper<T, TConfig>& other) {
|
||||||
|
CopyWeights<T, TConfig>()(get(), other.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
hwy::ThreadPool pool_;
|
||||||
|
ByteStorageT data_;
|
||||||
|
Weights<T, TConfig>* weights_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_RAW_H_
|
||||||
Loading…
Reference in New Issue