Merge pull request #224 from szabadka:cleanup

PiperOrigin-RevId: 641922102
This commit is contained in:
Copybara-Service 2024-06-10 09:11:13 -07:00
commit 49d814b519
10 changed files with 311 additions and 294 deletions

View File

@ -28,7 +28,6 @@
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -51,12 +50,12 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
template <size_t kCols, size_t kRows> template <size_t kCols, size_t kRows>
void MatMulVJP(const std::array<float, kRows * kCols>& weights, void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
const float* HWY_RESTRICT x, // num_tokens * kCols const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens, size_t num_tokens,
std::array<float, kRows * kCols>& grad_w, float* HWY_RESTRICT grad_w, // kRows * kCols,
float* HWY_RESTRICT grad_x, // num_tokens * kCols float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -72,12 +71,12 @@ void MatMulVJP(const std::array<float, kRows * kCols>& weights,
template <size_t kHeads, size_t kCols, size_t kRows> template <size_t kHeads, size_t kCols, size_t kRows>
void MultiHeadMatMulVJP( void MultiHeadMatMulVJP(
const std::array<float, kHeads * kRows * kCols>& weights, const float* HWY_RESTRICT weights, // kHeads * kRows * kCols
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens, size_t num_tokens,
std::array<float, kHeads * kRows * kCols>& grad_w, float* HWY_RESTRICT grad_w, // kHeads * kRows * kCols
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -166,12 +165,12 @@ static HWY_NOINLINE void InputEmbeddingVJP(
} }
} }
template <typename TConfig> template <typename TConfig, template<typename> typename LayerT>
void LayerVJP(const Layer<float, TConfig>& weights, void LayerVJP(const LayerT<TConfig>& weights,
const ForwardLayer<float, TConfig>& forward, const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad, const float* HWY_RESTRICT next_layer_grad,
size_t num_tokens, size_t num_tokens,
Layer<float, TConfig>& grad, LayerT<TConfig>& grad,
ForwardLayer<float, TConfig>& backward, ForwardLayer<float, TConfig>& backward,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
@ -184,8 +183,8 @@ void LayerVJP(const Layer<float, TConfig>& weights,
HWY_ASSERT(num_tokens <= kSeqLen); HWY_ASSERT(num_tokens <= kSeqLen);
MatMulVJP<kFFHiddenDim, kModelDim>( MatMulVJP<kFFHiddenDim, kModelDim>(
weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad, weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(), num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
pool); pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -210,9 +209,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
} }
MatMulVJP<kModelDim, kFFHiddenDim * 2>( MatMulVJP<kModelDim, kFFHiddenDim * 2>(
weights.gating_einsum_w, weights.gating_einsum_w.data(),
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(), forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
num_tokens, grad.gating_einsum_w, num_tokens, grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), pool); backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(), RMSNormVJP(weights.pre_ffw_norm_scale.data(),
forward.attention_out.data(), forward.attention_out.data(),
@ -230,9 +229,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0])); num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>( MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w, forward.att_out.data(), weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(), num_tokens, backward.attention_out.data(), num_tokens,
grad.attn_vec_einsum_w, backward.att_out.data(), pool); grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -293,9 +292,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
} }
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>( MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
weights.qkv_einsum_w, forward.pre_att_rms_out.data(), weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), num_tokens, backward.qkv.data(), num_tokens,
grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool); grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(), RMSNormVJP(weights.pre_attention_norm_scale.data(),
forward.input.data(), forward.input.data(),
backward.pre_att_rms_out.data(), backward.pre_att_rms_out.data(),
@ -345,11 +344,12 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
} }
} }
template <typename TConfig> template <typename TConfig, template<typename> typename WeightsT,
template<typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt, void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<float, TConfig>& weights, const WeightsT<TConfig>& weights,
const ForwardPass<float, TConfig>& forward, const ForwardPass<float, TConfig>& forward,
Weights<float, TConfig>& grad, WeightsT<TConfig>& grad,
ForwardPass<float, TConfig>& backward, ForwardPass<float, TConfig>& backward,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kVocabSize = TConfig::kVocabSize;
@ -379,9 +379,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
} }
MatMulVJP<kModelDim, kVocabSize>( MatMulVJP<kModelDim, kVocabSize>(
weights.embedder_input_embedding, forward.final_norm_output.data(), weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
backward.logits.data(), num_tokens, backward.logits.data(), num_tokens,
grad.embedder_input_embedding, backward.final_norm_output.data(), grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
pool); pool);
RMSNormVJP(weights.final_norm_scale.data(), RMSNormVJP(weights.final_norm_scale.data(),
@ -398,8 +398,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
float* next_layer_grad = layer + 1 < kLayers float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data() ? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data(); : backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, LayerVJP<TConfig, LayerT>(
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool); *weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
} }
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,

View File

@ -42,13 +42,14 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
ByteStorageT& grad_u8, ByteStorageT& grad_u8,
ByteStorageT& backward_u8, ByteStorageT& backward_u8,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
using TWeights = WeightsF<TConfig>; using TWeights = CompressedWeights<TConfig>;
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get()); const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get()); auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
using TAct = ForwardPass<float, TConfig>; using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get()); const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get()); auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool); CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt, weights, forward, grad, backward, pool);
} }
void CrossEntropyLossBackwardPassT(Model model, void CrossEntropyLossBackwardPassT(Model model,

View File

@ -84,8 +84,8 @@ void TestMatMulVJP() {
}; };
hwy::ZeroBytes(&grad, sizeof(grad)); hwy::ZeroBytes(&grad, sizeof(grad));
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens, MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
grad, dx.data(), pool); grad.data(), dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
@ -130,7 +130,8 @@ void TestMultiHeadMatMulVJP() {
hwy::ZeroBytes(&grad, sizeof(grad)); hwy::ZeroBytes(&grad, sizeof(grad));
MultiHeadMatMulVJP<kHeads, kCols, kRows>( MultiHeadMatMulVJP<kHeads, kCols, kRows>(
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
@ -235,7 +236,7 @@ void TestEndToEnd() {
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.clear(); grad.clear();
CrossEntropyLossBackwardPass( CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
prompt, weights.get(), forward1.get(), grad.get(), backward.get(), prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
pool); pool);

View File

@ -41,11 +41,12 @@ float CrossEntropyLossForwardPass(const Prompt& prompt,
ByteStorageT& forward_u8, ByteStorageT& forward_u8,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const auto& weights = const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
auto& forward = auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get()); *reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>( return
prompt.tokens, prompt.context_size, weights, forward, pool); CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt.tokens, prompt.context_size, weights, forward, pool);
} }
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,

View File

@ -34,19 +34,17 @@
namespace gcpp { namespace gcpp {
TEST(OptimizeTest, GradientDescent) { TEST(OptimizeTest, GradientDescent) {
if (kWeightsAreCompressed) return;
hwy::ThreadPool pool(0); hwy::ThreadPool pool(0);
std::mt19937 gen(42); std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY; Model model_type = Model::GEMMA_TINY;
Type weight_type = Type::kF32; Type weight_type = Type::kF32;
ByteStorageT grad = ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool); model_type, weight_type, pool);
ByteStorageT grad_m = ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool); model_type, weight_type, pool);
ByteStorageT grad_v = ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool); model_type, weight_type, pool);
ByteStorageT forward = ByteStorageT forward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type); CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
ByteStorageT backward = ByteStorageT backward =
@ -88,10 +86,10 @@ TEST(OptimizeTest, GradientDescent) {
}; };
RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen); RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m, CallForModelAndWeight<ZeroInitCompressedWeights>(
pool); model_type, weight_type, grad_m, pool);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v, CallForModelAndWeight<ZeroInitCompressedWeights>(
pool); model_type, weight_type, grad_v, pool);
printf("Initial weights:\n"); printf("Initial weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights()); LogWeightStats(model_type, weight_type, gemma.Weights());
@ -109,8 +107,8 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok; size_t num_ok;
for (; steps < 1000000; ++steps) { for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42); std::mt19937 sgen(42);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad, CallForModelAndWeight<ZeroInitCompressedWeights>(
pool); model_type, weight_type, grad, pool);
float total_loss = 0.0f; float total_loss = 0.0f;
num_ok = 0; num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) { for (size_t i = 0; i < kBatchSize; ++i) {
@ -139,7 +137,7 @@ TEST(OptimizeTest, GradientDescent) {
printf("Num steps: %zu\n", steps); printf("Num steps: %zu\n", steps);
printf("Final weights:\n"); printf("Final weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights()); LogWeightStats(model_type, weight_type, gemma.Weights());
EXPECT_LT(steps, 200); EXPECT_LT(steps, 300);
EXPECT_EQ(num_ok, kBatchSize); EXPECT_EQ(num_ok, kBatchSize);
} }

View File

@ -31,10 +31,12 @@ class WeightInitializer {
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
template <size_t N> template <size_t N>
void operator()(const char* name, std::array<float, N>& tensor) { void operator()(const char* name, CompressedArray<float, N>& tensor) {
float* data = tensor.data();
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
tensor[i] = dist_(gen_); data[i] = dist_(gen_);
} }
tensor.set_scale(1.0f);
} }
private: private:
std::normal_distribution<float> dist_; std::normal_distribution<float> dist_;
@ -45,11 +47,12 @@ template <typename TConfig>
struct RandInitWeightsT { struct RandInitWeightsT {
void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool, void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) const { std::mt19937& gen) const {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); auto& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python // TODO(szabadka) Use the same weight initialization method as in the python
// version. // version.
WeightInitializer init(gen); WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights); ForEachTensor1<TConfig>(init, weights);
} }
}; };
@ -62,18 +65,23 @@ class AdamUpdater {
norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {} norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
template <size_t kCapacity> template <size_t kCapacity>
void operator()(const char* name, const std::array<float, kCapacity>& grad, void operator()(const char* name,
std::array<float, kCapacity>& weights, const CompressedArray<float, kCapacity>& grad,
std::array<float, kCapacity>& grad_m, CompressedArray<float, kCapacity>& weights,
std::array<float, kCapacity>& grad_v) { CompressedArray<float, kCapacity>& grad_m,
CompressedArray<float, kCapacity>& grad_v) {
const float* HWY_RESTRICT g = grad.data();
float* HWY_RESTRICT w = weights.data();
float* HWY_RESTRICT m = grad_m.data();
float* HWY_RESTRICT v = grad_v.data();
for (size_t i = 0; i < kCapacity; ++i) { for (size_t i = 0; i < kCapacity; ++i) {
grad_m[i] *= beta1_; m[i] *= beta1_;
grad_m[i] += cbeta1_ * grad[i]; m[i] += cbeta1_ * g[i];
grad_v[i] *= beta2_; v[i] *= beta2_;
grad_v[i] += cbeta2_ * grad[i] * grad[i]; v[i] += cbeta2_ * g[i] * g[i];
const float mhat = grad_m[i] * norm1_; const float mhat = m[i] * norm1_;
const float vhat = grad_v[i] * norm2_; const float vhat = v[i] * norm2_;
weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_); w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
} }
} }
@ -94,13 +102,13 @@ struct AdamUpdateT {
float beta2, float epsilon, size_t t, float beta2, float epsilon, size_t t,
const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
const auto& grad = using TWeights = CompressedWeights<TConfig>;
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get()); const auto& grad = *reinterpret_cast<const TWeights*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); auto& weights = *reinterpret_cast<TWeights*>(weights_u8.get());
auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get()); auto& grad_m = *reinterpret_cast<TWeights*>(grad_m_u8.get());
auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get()); auto& grad_v = *reinterpret_cast<TWeights*>(grad_v_u8.get());
AdamUpdater updater(alpha, beta1, beta2, epsilon, t); AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v); ForEachTensor4<TConfig>(updater, grad, weights, grad_m, grad_v);
} }
}; };
@ -109,17 +117,17 @@ struct AdamUpdateT {
void RandInitWeights(Model model_type, Type weight_type, void RandInitWeights(Model model_type, Type weight_type,
const ByteStorageT& weights, hwy::ThreadPool& pool, const ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen) { std::mt19937& gen) {
CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights, HWY_ASSERT(weight_type == Type::kF32);
pool, gen); CallForModel<float, RandInitWeightsT>(model_type, weights, pool, gen);
} }
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
float alpha, float beta1, float beta2, float epsilon, size_t t, float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m, const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool) { const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha, HWY_ASSERT(weight_type == Type::kF32);
beta1, beta2, epsilon, t, weights, grad_m, CallForModel<float, AdamUpdateT>(model_type, grad, alpha, beta1, beta2,
grad_v, pool); epsilon, t, weights, grad_m, grad_v, pool);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -41,10 +41,162 @@
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/profiler.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;
namespace {
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;
}
return scale;
}
#define READ_WEIGHTS(name) \
do { \
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
} while (0)
#define SCALE_WEIGHTS(name) \
do { \
if (ok && !kDryRunFread && scale_for_compression) { \
weights->scales[scale_pos++] = \
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
} \
} while (0)
template <typename TConfig>
struct LoadRawWeightsT {
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression) const {
PROFILER_ZONE("Startup.LoadWeights");
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
if constexpr (kDryRunFread) {
fprintf(stderr, "Dry-Run, not reading model-file.\n");
} else {
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
}
bool ok = true;
uint64_t total_size = 0;
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
if (layer == -1) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
} else {
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
size, name);
}
if constexpr (!kDryRunFread) {
ok &= 1 == fread(var, size, 1, fptr);
total_size += size;
}
};
do_fread(&(weights->embedder_input_embedding), -1,
"embedder_input_embedding",
sizeof(weights->embedder_input_embedding));
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin.gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
}
if (!ok) {
HWY_ABORT(
"Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB",
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
}
if (!kDryRunFread) {
HWY_ASSERT(0 == fclose(fptr));
if (scale_for_compression) {
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
}
return weights_u8;
}
};
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression) {
return CallForModelAndWeight<LoadRawWeightsT>(
model_type, weight_type, weights, pool, scale_for_compression);
}
struct Args : public ArgsBase<Args> { struct Args : public ArgsBase<Args> {
static constexpr size_t kDefaultNumThreads = ~size_t{0}; static constexpr size_t kDefaultNumThreads = ~size_t{0};

View File

@ -723,8 +723,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
} }
template <class TConfig> template <class TConfig>
const WeightsT<TConfig>& GetWeights(const ByteStorageT& weights_u8) { const CompressedWeights<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
return *reinterpret_cast<const WeightsT<TConfig>*>(weights_u8.get()); return *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
} }
template <class TConfig, size_t kBatchSize> template <class TConfig, size_t kBatchSize>
@ -741,7 +741,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache, const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info, hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) { LayersOutputT* layers_output) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8); const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations = auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8); GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8); auto& activations = GetActivations<TConfig, 1>(decode_u8);
@ -884,7 +884,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
tokenizer_(tokenizer_path), tokenizer_(tokenizer_path),
model_type_(model_type), model_type_(model_type),
weight_type_(weight_type) { weight_type_(weight_type) {
weights_u8_ = LoadWeights(weights, model_type, weight_type, pool); weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type); prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type); decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
} }
@ -895,8 +895,9 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
tokenizer_(std::move(tokenizer)), tokenizer_(std::move(tokenizer)),
model_type_(model_type), model_type_(model_type),
weight_type_(weight_type) { weight_type_(weight_type) {
weights_u8_ = HWY_ASSERT(weight_type == Type::kF32);
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool); weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
model_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type); prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type); decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
} }

View File

@ -29,157 +29,6 @@
namespace gcpp { namespace gcpp {
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;
namespace {
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;
}
return scale;
}
#define READ_WEIGHTS(name) \
do { \
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
} while (0)
#define SCALE_WEIGHTS(name) \
do { \
if (ok && !kDryRunFread && scale_for_compression) { \
weights->scales[scale_pos++] = \
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
} \
} while (0)
template <typename TConfig>
struct LoadRawWeightsT {
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression) const {
PROFILER_ZONE("Startup.LoadWeights");
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
if constexpr (kDryRunFread) {
fprintf(stderr, "Dry-Run, not reading model-file.\n");
} else {
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
}
bool ok = true;
uint64_t total_size = 0;
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
if (layer == -1) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
} else {
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
size, name);
}
if constexpr (!kDryRunFread) {
ok &= 1 == fread(var, size, 1, fptr);
total_size += size;
}
};
do_fread(&(weights->embedder_input_embedding), -1,
"embedder_input_embedding",
sizeof(weights->embedder_input_embedding));
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin.gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
}
if (!ok) {
HWY_ABORT(
"Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB",
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
}
if (!kDryRunFread) {
HWY_ASSERT(0 == fclose(fptr));
if (scale_for_compression) {
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
}
return weights_u8;
}
};
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression) {
return CallForModelAndWeight<LoadRawWeightsT>(
model_type, weight_type, weights, pool, scale_for_compression);
}
namespace { namespace {
template <class TConfig> template <class TConfig>
struct LoadCompressedWeightsT { struct LoadCompressedWeightsT {
@ -234,16 +83,6 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
weights, pool); weights, pool);
} }
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
if constexpr (kWeightsAreCompressed) {
return LoadCompressedWeights(weights, model_type, weight_type, pool);
} else {
return LoadRawWeights(weights, model_type, weight_type, pool,
/*scale_for_compression=*/false);
}
}
namespace { namespace {
void LogVec(const char* name, const float* data, size_t len) { void LogVec(const char* name, const float* data, size_t len) {
hwy::Stats stats; hwy::Stats stats;
@ -257,7 +96,7 @@ void LogVec(const char* name, const float* data, size_t len) {
class WeightLogger { class WeightLogger {
public: public:
template <size_t N> template <size_t N>
void operator()(const char* name, const std::array<float, N>& tensor) { void operator()(const char* name, const CompressedArray<float, N>& tensor) {
LogVec(name, tensor.data(), N); LogVec(name, tensor.data(), N);
total_weights += N; total_weights += N;
} }
@ -268,9 +107,9 @@ template <typename TConfig>
struct LogWeightStatsT { struct LogWeightStatsT {
void operator()(const ByteStorageT& weights_u8) const { void operator()(const ByteStorageT& weights_u8) const {
const auto& weights = const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
WeightLogger logger; WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights); ForEachTensor1<TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights); printf("%-20s %12zu\n", "Total", logger.total_weights);
} }
}; };
@ -278,7 +117,8 @@ struct LogWeightStatsT {
void LogWeightStats(gcpp::Model model_type, Type weight_type, void LogWeightStats(gcpp::Model model_type, Type weight_type,
const ByteStorageT& weights) { const ByteStorageT& weights) {
CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights); HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, LogWeightStatsT>(model_type, weights);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -25,9 +25,6 @@
namespace gcpp { namespace gcpp {
// Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true;
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Uncompressed // Uncompressed
@ -213,11 +210,16 @@ struct CompressedLayerPointers {
template <class TConfig> template <class TConfig>
struct CompressedWeights { struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned. // No ctor/dtor, allocated via AllocateAligned.
using Weight = typename TConfig::Weight;
CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim> using WeightF32OrInputT =
hwy::If<hwy::IsSame<Weight, float>(), float, EmbedderInputT>;
CompressedArray<WeightF32OrInputT, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding; embedder_input_embedding;
CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale; using WeightF32OrBF16 =
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;
// Must be last so that the other arrays remain aligned. // Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs; CompressedLayerPointers<TConfig> c_layer_ptrs;
@ -233,10 +235,6 @@ struct CompressedWeights {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Interface // Interface
template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
WeightsF<TConfig>>;
// TODO: can we use TConfig::Weight instead of T? // TODO: can we use TConfig::Weight instead of T?
template <typename T, typename TConfig> template <typename T, typename TConfig>
struct AllocateWeights { struct AllocateWeights {
@ -256,6 +254,17 @@ struct AllocateWeightsF {
} }
}; };
template <typename TConfig>
struct AllocateCompressedWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
using TWeights = CompressedWeights<TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
return weights_u8;
}
};
template <typename T, typename TConfig> template <typename T, typename TConfig>
struct ZeroInitWeights { struct ZeroInitWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
@ -277,6 +286,20 @@ struct ZeroInitWeightsF {
} }
}; };
template <typename TConfig>
struct ZeroInitCompressedWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& w =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
}
};
template <typename T, typename TConfig> template <typename T, typename TConfig>
struct CopyWeights { struct CopyWeights {
void operator()(Weights<T, TConfig>& dst, void operator()(Weights<T, TConfig>& dst,
@ -295,12 +318,9 @@ void operator()(Weights<T, TConfig>& dst,
template <class TConfig> template <class TConfig>
struct DeleteLayersPtrs { struct DeleteLayersPtrs {
void operator()(ByteStorageT& weights_u8) const { void operator()(ByteStorageT& weights_u8) const {
auto* weights = reinterpret_cast<WeightsT<TConfig>*>(weights_u8.get()); auto* weights =
if constexpr (kWeightsAreCompressed) { reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>(); weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
} else {
weights->layer_ptrs.~LayerPointers<float, TConfig>();
}
} }
}; };
@ -330,14 +350,8 @@ class WeightsWrapper {
Weights<T, TConfig>* weights_; Weights<T, TConfig>* weights_;
}; };
// For use by compress_weights.cc. ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
ByteStorageT LoadRawWeights(const Path& weights, Model model_type, Type weight_type, hwy::ThreadPool& pool);
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression);
// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights); void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
@ -467,62 +481,62 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \ GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
} }
template <typename T, typename TConfig, class Func> template <typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) { void ForEachTensor1(Func& func, const CompressedWeights<TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1) GEMMA_CALL_ALL_LAYER_FUNC(1)
} }
} }
template <typename T, typename TConfig, class Func> template <typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) { void ForEachTensor1(Func& func, CompressedWeights<TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1) GEMMA_CALL_ALL_LAYER_FUNC(1)
} }
} }
template <typename T, typename TConfig, class Func> template <typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1, void ForEachTensor2(Func& func, const CompressedWeights<TConfig>& weights1,
Weights<T, TConfig>& weights2) { CompressedWeights<TConfig>& weights2) {
GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx); CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(2) GEMMA_CALL_ALL_LAYER_FUNC(2)
} }
} }
template <typename T, typename TConfig, class Func> template <typename TConfig, class Func>
void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1, void ForEachTensor4(Func& func, const CompressedWeights<TConfig>& weights1,
Weights<T, TConfig>& weights2, CompressedWeights<TConfig>& weights2,
Weights<T, TConfig>& weights3, CompressedWeights<TConfig>& weights3,
Weights<T, TConfig>& weights4) { CompressedWeights<TConfig>& weights4) {
GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding); GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale); GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
char name_buf[16]; char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx]; auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx); const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx); const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx); CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
LayerF<TConfig>& layer3 = *weights3.GetLayer(idx); CompressedLayer<TConfig>& layer3 = *weights3.GetLayer(idx);
LayerF<TConfig>& layer4 = *weights4.GetLayer(idx); CompressedLayer<TConfig>& layer4 = *weights4.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(4) GEMMA_CALL_ALL_LAYER_FUNC(4)
} }
} }