mirror of https://github.com/google/gemma.cpp.git
Merge pull request #224 from szabadka:cleanup
PiperOrigin-RevId: 641922102
This commit is contained in:
commit 49d814b519
@@ -28,7 +28,6 @@
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"
-#include "gemma/weights.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -51,11 +50,11 @@ namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
 template <size_t kCols, size_t kRows>
-void MatMulVJP(const std::array<float, kRows * kCols>& weights,
+void MatMulVJP(const float* HWY_RESTRICT weights,  // kRows * kCols,
                const float* HWY_RESTRICT x,  // num_tokens * kCols
                const float* HWY_RESTRICT v,  // num_tokens * kRows
                size_t num_tokens,
-               std::array<float, kRows * kCols>& grad_w,
+               float* HWY_RESTRICT grad_w,  // kRows * kCols,
               float* HWY_RESTRICT grad_x,  // num_tokens * kCols
               hwy::ThreadPool& pool) {
  hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
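Note: the signature change above is mechanical (a std::array reference becomes a raw pointer plus a size comment), but it helps to spell out what the kernel computes. Below is a minimal scalar sketch of the vector-Jacobian product for y = W x, assuming row-major weights; the function name and loop structure are illustrative only, the real kernel is Highway-vectorized and parallelized over rows via the thread pool:

// Forward: y[t][r] = sum_c W[r * kCols + c] * x[t][c].
// Given v = dL/dy, accumulate dL/dW = sum_t v[t] x[t]^T and set dL/dx = W^T v.
// grad_x is zeroed at entry in the real kernel (the hwy::ZeroBytes above);
// grad_w accumulates, so callers zero it first.
void MatMulVJPScalar(const float* W, const float* x, const float* v,
                     size_t kRows, size_t kCols, size_t num_tokens,
                     float* grad_w, float* grad_x) {
  for (size_t t = 0; t < num_tokens; ++t) {
    const float* xt = x + t * kCols;
    const float* vt = v + t * kRows;
    float* gxt = grad_x + t * kCols;
    for (size_t r = 0; r < kRows; ++r) {
      for (size_t c = 0; c < kCols; ++c) {
        grad_w[r * kCols + c] += vt[r] * xt[c];
        gxt[c] += vt[r] * W[r * kCols + c];
      }
    }
  }
}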
@@ -72,11 +71,11 @@ void MatMulVJP(const std::array<float, kRows * kCols>& weights,
 
 template <size_t kHeads, size_t kCols, size_t kRows>
 void MultiHeadMatMulVJP(
-    const std::array<float, kHeads * kRows * kCols>& weights,
+    const float* HWY_RESTRICT weights,  // kHeads * kRows * kCols
     const float* HWY_RESTRICT x,  // num_tokens * kHeads * kCols
     const float* HWY_RESTRICT v,  // num_tokens * kRows
     size_t num_tokens,
-    std::array<float, kHeads * kRows * kCols>& grad_w,
+    float* HWY_RESTRICT grad_w,  // kHeads * kRows * kCols
    float* HWY_RESTRICT grad_x,  // num_tokens * kHeads * kCols
    hwy::ThreadPool& pool) {
  hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
@@ -166,12 +165,12 @@ static HWY_NOINLINE void InputEmbeddingVJP(
   }
 }
 
-template <typename TConfig>
-void LayerVJP(const Layer<float, TConfig>& weights,
+template <typename TConfig, template<typename> typename LayerT>
+void LayerVJP(const LayerT<TConfig>& weights,
               const ForwardLayer<float, TConfig>& forward,
               const float* HWY_RESTRICT next_layer_grad,
               size_t num_tokens,
-              Layer<float, TConfig>& grad,
+              LayerT<TConfig>& grad,
               ForwardLayer<float, TConfig>& backward,
               hwy::ThreadPool& pool) {
   static constexpr size_t kModelDim = TConfig::kModelDim;
@@ -184,8 +183,8 @@ void LayerVJP(const Layer<float, TConfig>& weights,
   HWY_ASSERT(num_tokens <= kSeqLen);
 
   MatMulVJP<kFFHiddenDim, kModelDim>(
-      weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
-      num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
+      weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
+      num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
       pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
@@ -210,9 +209,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
   }
 
   MatMulVJP<kModelDim, kFFHiddenDim * 2>(
-      weights.gating_einsum_w,
+      weights.gating_einsum_w.data(),
       forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
-      num_tokens, grad.gating_einsum_w,
+      num_tokens, grad.gating_einsum_w.data(),
       backward.bf_pre_ffw_rms_out.data(), pool);
   RMSNormVJP(weights.pre_ffw_norm_scale.data(),
              forward.attention_out.data(),
@@ -230,9 +229,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
                  num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
 
   MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
-      weights.attn_vec_einsum_w, forward.att_out.data(),
+      weights.attn_vec_einsum_w.data(), forward.att_out.data(),
       backward.attention_out.data(), num_tokens,
-      grad.attn_vec_einsum_w, backward.att_out.data(), pool);
+      grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
 
   for (size_t head = 0; head < kHeads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
@@ -293,9 +292,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
   }
 
   MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
-      weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
+      weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
       backward.qkv.data(), num_tokens,
-      grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
+      grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
   RMSNormVJP(weights.pre_attention_norm_scale.data(),
              forward.input.data(),
              backward.pre_att_rms_out.data(),
@@ -345,11 +344,12 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
   }
 }
 
-template <typename TConfig>
+template <typename TConfig, template<typename> typename WeightsT,
+          template<typename> typename LayerT>
 void CrossEntropyLossBackwardPass(const Prompt& prompt,
-                                  const Weights<float, TConfig>& weights,
+                                  const WeightsT<TConfig>& weights,
                                   const ForwardPass<float, TConfig>& forward,
-                                  Weights<float, TConfig>& grad,
+                                  WeightsT<TConfig>& grad,
                                   ForwardPass<float, TConfig>& backward,
                                   hwy::ThreadPool& pool) {
   static constexpr size_t kVocabSize = TConfig::kVocabSize;
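The recurring pattern in this commit is the template-template parameter: the weight and layer containers (WeightsF/LayerF or CompressedWeights/CompressedLayer) are passed as one-argument templates rather than concrete types, so a single backward-pass body serves both storage families. A small self-contained illustration of the mechanism; the types here are stand-ins, not the repo's:

#include <cstdio>

template <typename TConfig> struct PlainLayer {};
template <typename TConfig> struct PackedLayer {};

struct TinyConfig { static constexpr int kLayers = 2; };

// LayerT is itself a template; the same body works for any container
// family that exposes the same member names.
template <typename TConfig, template <typename> typename LayerT>
void Visit(const LayerT<TConfig>&) {
  std::printf("layers: %d\n", TConfig::kLayers);
}

int main() {
  PlainLayer<TinyConfig> a;
  PackedLayer<TinyConfig> b;
  Visit<TinyConfig, PlainLayer>(a);  // explicit, as in LayerVJP<TConfig, LayerT>(...)
  Visit(b);                          // or deduced from the argument
}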
@@ -379,9 +379,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   }
 
   MatMulVJP<kModelDim, kVocabSize>(
-      weights.embedder_input_embedding, forward.final_norm_output.data(),
+      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
       backward.logits.data(), num_tokens,
-      grad.embedder_input_embedding, backward.final_norm_output.data(),
+      grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
       pool);
 
   RMSNormVJP(weights.final_norm_scale.data(),
@@ -398,7 +398,8 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
     float* next_layer_grad = layer + 1 < kLayers
                                  ? backward.layers[layer + 1].input.data()
                                  : backward.final_layer_output.data();
-    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
+    LayerVJP<TConfig, LayerT>(
+        *weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
         num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
   }
 
@@ -42,13 +42,14 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
                                   ByteStorageT& grad_u8,
                                   ByteStorageT& backward_u8,
                                   hwy::ThreadPool& pool) {
-  using TWeights = WeightsF<TConfig>;
+  using TWeights = CompressedWeights<TConfig>;
   const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
   auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
   using TAct = ForwardPass<float, TConfig>;
   const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
   auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
-  CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
+  CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
+      prompt, weights, forward, grad, backward, pool);
 }
 
 void CrossEntropyLossBackwardPassT(Model model,
@@ -84,8 +84,8 @@ void TestMatMulVJP() {
   };
 
   hwy::ZeroBytes(&grad, sizeof(grad));
-  MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
-                          grad, dx.data(), pool);
+  MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
+                          grad.data(), dx.data(), pool);
   TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
   TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
 
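The tests hand TestGradient an objective functor plus separate copies of the inputs (the c_-prefixed c_x and c_weights suggest an independent, possibly complex-step, evaluation) and check the analytic gradients from the VJP to within 5e-5. The sketch below shows the simplest form of such a check, a central finite difference; it illustrates the idea only and is not the repo's TestGradient implementation:

#include <cmath>
#include <cstdio>
#include <functional>

// Compare an analytic derivative of f at x against a central difference.
bool CheckGrad(const std::function<double(double)>& f, double x,
               double analytic, double max_abs_err = 5e-5) {
  const double h = 1e-4;
  const double numeric = (f(x + h) - f(x - h)) / (2 * h);
  return std::abs(numeric - analytic) <= max_abs_err;
}

int main() {
  auto f = [](double x) { return x * x; };
  std::printf("ok=%d\n", CheckGrad(f, 3.0, 6.0));  // d/dx x^2 at x=3 is 6
}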
@@ -130,7 +130,8 @@ void TestMultiHeadMatMulVJP() {
 
   hwy::ZeroBytes(&grad, sizeof(grad));
   MultiHeadMatMulVJP<kHeads, kCols, kRows>(
-      weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
+      weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
+      pool);
   TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
   TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
 
@@ -235,7 +236,7 @@ void TestEndToEnd() {
   EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
 
   grad.clear();
-  CrossEntropyLossBackwardPass(
+  CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
       prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
       pool);
 
@@ -41,10 +41,11 @@ float CrossEntropyLossForwardPass(const Prompt& prompt,
                                   ByteStorageT& forward_u8,
                                   hwy::ThreadPool& pool) {
   const auto& weights =
-      *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+      *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
   auto& forward =
       *reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
-  return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>(
+  return
+      CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>(
       prompt.tokens, prompt.context_size, weights, forward, pool);
 }
 
@@ -34,19 +34,17 @@
 namespace gcpp {
 
 TEST(OptimizeTest, GradientDescent) {
-  if (kWeightsAreCompressed) return;
-
   hwy::ThreadPool pool(0);
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
   Type weight_type = Type::kF32;
-  ByteStorageT grad =
-      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
-  ByteStorageT grad_m =
-      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
-  ByteStorageT grad_v =
-      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
+      model_type, weight_type, pool);
+  ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
+      model_type, weight_type, pool);
+  ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
+      model_type, weight_type, pool);
   ByteStorageT forward =
       CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
   ByteStorageT backward =
@@ -88,10 +86,10 @@ TEST(OptimizeTest, GradientDescent) {
   };
 
   RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
-  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
-                                          pool);
-  CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
-                                          pool);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(
+      model_type, weight_type, grad_m, pool);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(
+      model_type, weight_type, grad_v, pool);
 
   printf("Initial weights:\n");
   LogWeightStats(model_type, weight_type, gemma.Weights());
@@ -109,8 +107,8 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
-                                            pool);
+    CallForModelAndWeight<ZeroInitCompressedWeights>(
+        model_type, weight_type, grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -139,7 +137,7 @@ TEST(OptimizeTest, GradientDescent) {
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
   LogWeightStats(model_type, weight_type, gemma.Weights());
-  EXPECT_LT(steps, 200);
+  EXPECT_LT(steps, 300);
   EXPECT_EQ(num_ok, kBatchSize);
 }
 
@@ -31,10 +31,12 @@ class WeightInitializer {
   WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
 
   template <size_t N>
-  void operator()(const char* name, std::array<float, N>& tensor) {
+  void operator()(const char* name, CompressedArray<float, N>& tensor) {
+    float* data = tensor.data();
     for (size_t i = 0; i < N; ++i) {
-      tensor[i] = dist_(gen_);
+      data[i] = dist_(gen_);
     }
+    tensor.set_scale(1.0f);
   }
  private:
   std::normal_distribution<float> dist_;
@@ -45,11 +47,12 @@ template <typename TConfig>
 struct RandInitWeightsT {
   void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
                   std::mt19937& gen) const {
-    auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+    auto& weights =
+        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
     // TODO(szabadka) Use the same weight initialization method as in the python
     // version.
     WeightInitializer init(gen);
-    ForEachTensor1<float, TConfig>(init, weights);
+    ForEachTensor1<TConfig>(init, weights);
   }
 };
 
@@ -62,18 +65,23 @@ class AdamUpdater {
         norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
 
   template <size_t kCapacity>
-  void operator()(const char* name, const std::array<float, kCapacity>& grad,
-                  std::array<float, kCapacity>& weights,
-                  std::array<float, kCapacity>& grad_m,
-                  std::array<float, kCapacity>& grad_v) {
+  void operator()(const char* name,
+                  const CompressedArray<float, kCapacity>& grad,
+                  CompressedArray<float, kCapacity>& weights,
+                  CompressedArray<float, kCapacity>& grad_m,
+                  CompressedArray<float, kCapacity>& grad_v) {
+    const float* HWY_RESTRICT g = grad.data();
+    float* HWY_RESTRICT w = weights.data();
+    float* HWY_RESTRICT m = grad_m.data();
+    float* HWY_RESTRICT v = grad_v.data();
     for (size_t i = 0; i < kCapacity; ++i) {
-      grad_m[i] *= beta1_;
-      grad_m[i] += cbeta1_ * grad[i];
-      grad_v[i] *= beta2_;
-      grad_v[i] += cbeta2_ * grad[i] * grad[i];
-      const float mhat = grad_m[i] * norm1_;
-      const float vhat = grad_v[i] * norm2_;
-      weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
+      m[i] *= beta1_;
+      m[i] += cbeta1_ * g[i];
+      v[i] *= beta2_;
+      v[i] += cbeta2_ * g[i] * g[i];
+      const float mhat = m[i] * norm1_;
+      const float vhat = v[i] * norm2_;
+      w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
     }
   }
 
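For reference, the loop above is the textbook Adam update with bias correction; in the code's notation, cbeta1_ = 1 - beta1 and norm1_ = 1 / (1 - beta1^t), and likewise for beta2:

m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
\hat{m}_t = m_t / (1-\beta_1^t), \qquad \hat{v}_t = v_t / (1-\beta_2^t)
w_t = w_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)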
@@ -94,13 +102,13 @@ struct AdamUpdateT {
                   float beta2, float epsilon, size_t t,
                   const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
                   const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
-    const auto& grad =
-        *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
-    auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
-    auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
-    auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
+    using TWeights = CompressedWeights<TConfig>;
+    const auto& grad = *reinterpret_cast<const TWeights*>(grad_u8.get());
+    auto& weights = *reinterpret_cast<TWeights*>(weights_u8.get());
+    auto& grad_m = *reinterpret_cast<TWeights*>(grad_m_u8.get());
+    auto& grad_v = *reinterpret_cast<TWeights*>(grad_v_u8.get());
     AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
-    ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v);
+    ForEachTensor4<TConfig>(updater, grad, weights, grad_m, grad_v);
   }
 };
 
@@ -109,17 +117,17 @@ struct AdamUpdateT {
 void RandInitWeights(Model model_type, Type weight_type,
                      const ByteStorageT& weights, hwy::ThreadPool& pool,
                      std::mt19937& gen) {
-  CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
-                                          pool, gen);
+  HWY_ASSERT(weight_type == Type::kF32);
+  CallForModel<float, RandInitWeightsT>(model_type, weights, pool, gen);
 }
 
 void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
                 float alpha, float beta1, float beta2, float epsilon, size_t t,
                 const ByteStorageT& weights, const ByteStorageT& grad_m,
                 const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
-  CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
-                                     beta1, beta2, epsilon, t, weights, grad_m,
-                                     grad_v, pool);
+  HWY_ASSERT(weight_type == Type::kF32);
+  CallForModel<float, AdamUpdateT>(model_type, grad, alpha, beta1, beta2,
+                                   epsilon, t, weights, grad_m, grad_v, pool);
 }
 
 }  // namespace gcpp
@@ -41,10 +41,162 @@
 #include "gemma/weights.h"
 #include "util/args.h"
 #include "hwy/base.h"
+#include "hwy/profiler.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
+// Setting this to true disables fread() calls that read the model file.
+constexpr bool kDryRunFread = false;
+
+namespace {
+float ScaleWeights(float* data, size_t len) {
+  float maxabs = 0.0;
+  for (size_t i = 0; i < len; ++i) {
+    maxabs = std::max(maxabs, std::abs(data[i]));
+  }
+  const float kMaxRange = 1.875f;
+  if (maxabs <= kMaxRange) {
+    return 1.0f;
+  }
+  const float scale = maxabs / kMaxRange;
+  const float inv_scale = 1.0f / scale;
+  for (size_t i = 0; i < len; ++i) {
+    data[i] *= inv_scale;
+  }
+  return scale;
+}
+
+#define READ_WEIGHTS(name) \
+  do { \
+    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
+  } while (0)
+
+#define SCALE_WEIGHTS(name) \
+  do { \
+    if (ok && !kDryRunFread && scale_for_compression) { \
+      weights->scales[scale_pos++] = \
+          ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
+    } \
+  } while (0)
+
+template <typename TConfig>
+struct LoadRawWeightsT {
+  ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
+                          bool scale_for_compression) const {
+    PROFILER_ZONE("Startup.LoadWeights");
+    if (!checkpoint.Exists()) {
+      HWY_ABORT("The model weights file '%s' does not exist.",
+                checkpoint.path.c_str());
+    }
+
+    ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
+    auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+
+    size_t scale_pos = 0;
+    FILE* fptr;
+    if constexpr (kDryRunFread) {
+      fprintf(stderr, "Dry-Run, not reading model-file.\n");
+    } else {
+      fptr = fopen(checkpoint.path.c_str(), "rb");
+      if (fptr == nullptr) {
+        HWY_ABORT("Failed to open model file %s - does it exist?",
+                  checkpoint.path.c_str());
+      }
+    }
+    bool ok = true;
+    uint64_t total_size = 0;
+    auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
+      if (layer == -1) {
+        fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
+      } else {
+        fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
+                size, name);
+      }
+      if constexpr (!kDryRunFread) {
+        ok &= 1 == fread(var, size, 1, fptr);
+        total_size += size;
+      }
+    };
+    do_fread(&(weights->embedder_input_embedding), -1,
+             "embedder_input_embedding",
+             sizeof(weights->embedder_input_embedding));
+    do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
+             sizeof(weights->final_norm_scale));
+    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
+      auto type = TConfig::kLayerConfig[layer];
+      LayerF<TConfig>* layer_view = weights->GetLayer(layer);
+
+      // Make sure we don't have uninitialized memory.
+      hwy::ZeroBytes(layer_view, sizeof(*layer_view));
+      if (type == LayerAttentionType::kGemma) {
+        READ_WEIGHTS(attn_vec_einsum_w);
+        READ_WEIGHTS(qkv_einsum_w);
+        SCALE_WEIGHTS(attn_vec_einsum_w);
+        SCALE_WEIGHTS(qkv_einsum_w);
+      } else {
+        READ_WEIGHTS(griffin.linear_x_w);
+        READ_WEIGHTS(griffin.linear_x_biases);
+        READ_WEIGHTS(griffin.linear_y_w);
+        READ_WEIGHTS(griffin.linear_y_biases);
+        READ_WEIGHTS(griffin.linear_out_w);
+        READ_WEIGHTS(griffin.linear_out_biases);
+        READ_WEIGHTS(griffin.conv_w);
+        READ_WEIGHTS(griffin.conv_biases);
+        READ_WEIGHTS(griffin.gate_w);
+        READ_WEIGHTS(griffin.gate_biases);
+        READ_WEIGHTS(griffin.a);
+        SCALE_WEIGHTS(griffin.linear_x_w);
+        SCALE_WEIGHTS(griffin.linear_y_w);
+        SCALE_WEIGHTS(griffin.linear_out_w);
+        SCALE_WEIGHTS(griffin.gate_w);
+      }
+      READ_WEIGHTS(gating_einsum_w);
+      READ_WEIGHTS(linear_w);
+      SCALE_WEIGHTS(gating_einsum_w);
+      SCALE_WEIGHTS(linear_w);
+      READ_WEIGHTS(pre_attention_norm_scale);
+      READ_WEIGHTS(pre_ffw_norm_scale);
+      if (TConfig::kPostNormScale) {
+        READ_WEIGHTS(post_attention_norm_scale);
+        READ_WEIGHTS(post_ffw_norm_scale);
+      }
+      if (TConfig::kFFBiases) {
+        READ_WEIGHTS(ffw_gating_biases);
+        READ_WEIGHTS(ffw_output_biases);
+      }
+      if (TConfig::kSoftmaxAttnOutputBiases &&
+          type == LayerAttentionType::kGemma) {
+        READ_WEIGHTS(attention_output_biases);
+      }
+    }
+    if (!ok) {
+      HWY_ABORT(
+          "Failed to read from %s - might be a directory, or too small? "
+          "expected size: %d kB",
+          checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
+    }
+    if (!kDryRunFread) {
+      HWY_ASSERT(0 == fclose(fptr));
+      if (scale_for_compression) {
+        HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
+      }
+    }
+    return weights_u8;
+  }
+};
+
+#undef READ_WEIGHTS
+#undef SCALE_WEIGHTS
+}  // namespace
+
+ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
+                            Type weight_type, hwy::ThreadPool& pool,
+                            bool scale_for_compression) {
+  return CallForModelAndWeight<LoadRawWeightsT>(
+      model_type, weight_type, weights, pool, scale_for_compression);
+}
+
 struct Args : public ArgsBase<Args> {
   static constexpr size_t kDefaultNumThreads = ~size_t{0};
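ScaleWeights, moved into this file from gemma/weights.cc, folds large tensors into the representable range of the compressed format: if any |w| exceeds kMaxRange = 1.875, the whole tensor is divided by maxabs / kMaxRange and that factor is recorded in weights->scales for later re-application. A standalone sketch of the same normalization; the names here are illustrative:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

float NormalizeIntoRange(std::vector<float>& w, float max_range = 1.875f) {
  float maxabs = 0.0f;
  for (float x : w) maxabs = std::max(maxabs, std::abs(x));
  if (maxabs <= max_range) return 1.0f;  // already in range, identity scale
  const float scale = maxabs / max_range;
  for (float& x : w) x /= scale;
  return scale;  // caller stores this, e.g. via CompressedArray::set_scale
}

int main() {
  std::vector<float> w = {0.5f, -3.75f, 1.0f};
  const float scale = NormalizeIntoRange(w);
  std::printf("scale=%g w[1]=%g\n", scale, w[1]);  // scale=2, w[1]=-1.875
}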
@@ -723,8 +723,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 }
 
 template <class TConfig>
-const WeightsT<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
-  return *reinterpret_cast<const WeightsT<TConfig>*>(weights_u8.get());
+const CompressedWeights<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
+  return *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
 }
 
 template <class TConfig, size_t kBatchSize>
@@ -741,7 +741,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
                const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                hwy::ThreadPool& pool, TimingInfo& timing_info,
                LayersOutputT* layers_output) {
-  const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
+  const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
   auto& prefill_activations =
       GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
   auto& activations = GetActivations<TConfig, 1>(decode_u8);
@@ -884,7 +884,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
       tokenizer_(tokenizer_path),
       model_type_(model_type),
       weight_type_(weight_type) {
-  weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
+  weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
   prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
   decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
@@ -895,8 +895,9 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
       tokenizer_(std::move(tokenizer)),
       model_type_(model_type),
      weight_type_(weight_type) {
-  weights_u8_ =
-      CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
+  HWY_ASSERT(weight_type == Type::kF32);
+  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
+      model_type, pool);
   prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
   decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
 }
gemma/weights.cc (170 changed lines)
@@ -29,157 +29,6 @@
 
 namespace gcpp {
 
-// Setting this to true disables fread() calls that read the model file.
-constexpr bool kDryRunFread = false;
-
-namespace {
-float ScaleWeights(float* data, size_t len) {
-  float maxabs = 0.0;
-  for (size_t i = 0; i < len; ++i) {
-    maxabs = std::max(maxabs, std::abs(data[i]));
-  }
-  const float kMaxRange = 1.875f;
-  if (maxabs <= kMaxRange) {
-    return 1.0f;
-  }
-  const float scale = maxabs / kMaxRange;
-  const float inv_scale = 1.0f / scale;
-  for (size_t i = 0; i < len; ++i) {
-    data[i] *= inv_scale;
-  }
-  return scale;
-}
-
-#define READ_WEIGHTS(name) \
-  do { \
-    do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
-  } while (0)
-
-#define SCALE_WEIGHTS(name) \
-  do { \
-    if (ok && !kDryRunFread && scale_for_compression) { \
-      weights->scales[scale_pos++] = \
-          ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
-    } \
-  } while (0)
-
-template <typename TConfig>
-struct LoadRawWeightsT {
-  ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
-                          bool scale_for_compression) const {
-    PROFILER_ZONE("Startup.LoadWeights");
-    if (!checkpoint.Exists()) {
-      HWY_ABORT("The model weights file '%s' does not exist.",
-                checkpoint.path.c_str());
-    }
-
-    ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
-    auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
-
-    size_t scale_pos = 0;
-    FILE* fptr;
-    if constexpr (kDryRunFread) {
-      fprintf(stderr, "Dry-Run, not reading model-file.\n");
-    } else {
-      fptr = fopen(checkpoint.path.c_str(), "rb");
-      if (fptr == nullptr) {
-        HWY_ABORT("Failed to open model file %s - does it exist?",
-                  checkpoint.path.c_str());
-      }
-    }
-    bool ok = true;
-    uint64_t total_size = 0;
-    auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
-      if (layer == -1) {
-        fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
-      } else {
-        fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
-                size, name);
-      }
-      if constexpr (!kDryRunFread) {
-        ok &= 1 == fread(var, size, 1, fptr);
-        total_size += size;
-      }
-    };
-    do_fread(&(weights->embedder_input_embedding), -1,
-             "embedder_input_embedding",
-             sizeof(weights->embedder_input_embedding));
-    do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
-             sizeof(weights->final_norm_scale));
-    for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
-      auto type = TConfig::kLayerConfig[layer];
-      LayerF<TConfig>* layer_view = weights->GetLayer(layer);
-
-      // Make sure we don't have uninitialized memory.
-      hwy::ZeroBytes(layer_view, sizeof(*layer_view));
-      if (type == LayerAttentionType::kGemma) {
-        READ_WEIGHTS(attn_vec_einsum_w);
-        READ_WEIGHTS(qkv_einsum_w);
-        SCALE_WEIGHTS(attn_vec_einsum_w);
-        SCALE_WEIGHTS(qkv_einsum_w);
-      } else {
-        READ_WEIGHTS(griffin.linear_x_w);
-        READ_WEIGHTS(griffin.linear_x_biases);
-        READ_WEIGHTS(griffin.linear_y_w);
-        READ_WEIGHTS(griffin.linear_y_biases);
-        READ_WEIGHTS(griffin.linear_out_w);
-        READ_WEIGHTS(griffin.linear_out_biases);
-        READ_WEIGHTS(griffin.conv_w);
-        READ_WEIGHTS(griffin.conv_biases);
-        READ_WEIGHTS(griffin.gate_w);
-        READ_WEIGHTS(griffin.gate_biases);
-        READ_WEIGHTS(griffin.a);
-        SCALE_WEIGHTS(griffin.linear_x_w);
-        SCALE_WEIGHTS(griffin.linear_y_w);
-        SCALE_WEIGHTS(griffin.linear_out_w);
-        SCALE_WEIGHTS(griffin.gate_w);
-      }
-      READ_WEIGHTS(gating_einsum_w);
-      READ_WEIGHTS(linear_w);
-      SCALE_WEIGHTS(gating_einsum_w);
-      SCALE_WEIGHTS(linear_w);
-      READ_WEIGHTS(pre_attention_norm_scale);
-      READ_WEIGHTS(pre_ffw_norm_scale);
-      if (TConfig::kPostNormScale) {
-        READ_WEIGHTS(post_attention_norm_scale);
-        READ_WEIGHTS(post_ffw_norm_scale);
-      }
-      if (TConfig::kFFBiases) {
-        READ_WEIGHTS(ffw_gating_biases);
-        READ_WEIGHTS(ffw_output_biases);
-      }
-      if (TConfig::kSoftmaxAttnOutputBiases &&
-          type == LayerAttentionType::kGemma) {
-        READ_WEIGHTS(attention_output_biases);
-      }
-    }
-    if (!ok) {
-      HWY_ABORT(
-          "Failed to read from %s - might be a directory, or too small? "
-          "expected size: %d kB",
-          checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
-    }
-    if (!kDryRunFread) {
-      HWY_ASSERT(0 == fclose(fptr));
-      if (scale_for_compression) {
-        HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
-      }
-    }
-    return weights_u8;
-  }
-};
-
-#undef READ_WEIGHTS
-#undef SCALE_WEIGHTS
-}  // namespace
-
-ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
-                            Type weight_type, hwy::ThreadPool& pool,
-                            bool scale_for_compression) {
-  return CallForModelAndWeight<LoadRawWeightsT>(
-      model_type, weight_type, weights, pool, scale_for_compression);
-}
-
 namespace {
 template <class TConfig>
 struct LoadCompressedWeightsT {
@@ -234,16 +83,6 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
                                      weights, pool);
 }
 
-ByteStorageT LoadWeights(const Path& weights, Model model_type,
-                         Type weight_type, hwy::ThreadPool& pool) {
-  if constexpr (kWeightsAreCompressed) {
-    return LoadCompressedWeights(weights, model_type, weight_type, pool);
-  } else {
-    return LoadRawWeights(weights, model_type, weight_type, pool,
-                          /*scale_for_compression=*/false);
-  }
-}
-
 namespace {
 void LogVec(const char* name, const float* data, size_t len) {
   hwy::Stats stats;
@@ -257,7 +96,7 @@ void LogVec(const char* name, const float* data, size_t len) {
 class WeightLogger {
  public:
   template <size_t N>
-  void operator()(const char* name, const std::array<float, N>& tensor) {
+  void operator()(const char* name, const CompressedArray<float, N>& tensor) {
     LogVec(name, tensor.data(), N);
     total_weights += N;
   }
@@ -268,9 +107,9 @@ template <typename TConfig>
 struct LogWeightStatsT {
   void operator()(const ByteStorageT& weights_u8) const {
     const auto& weights =
-        *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+        *reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
     WeightLogger logger;
-    ForEachTensor1<float, TConfig>(logger, weights);
+    ForEachTensor1<TConfig>(logger, weights);
     printf("%-20s %12zu\n", "Total", logger.total_weights);
   }
 };
@@ -278,7 +117,8 @@ struct LogWeightStatsT {
 
 void LogWeightStats(gcpp::Model model_type, Type weight_type,
                     const ByteStorageT& weights) {
-  CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
+  HWY_ASSERT(weight_type == Type::kF32);
+  CallForModel<float, LogWeightStatsT>(model_type, weights);
 }
 
 }  // namespace gcpp
@@ -25,9 +25,6 @@
 namespace gcpp {
 
-// Setting this to false will load and use uncompressed weights.
-constexpr bool kWeightsAreCompressed = true;
-
 // ----------------------------------------------------------------------------
 // Uncompressed
 
@@ -213,11 +210,16 @@ struct CompressedLayerPointers {
 template <class TConfig>
 struct CompressedWeights {
   // No ctor/dtor, allocated via AllocateAligned.
+  using Weight = typename TConfig::Weight;
+
-  CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
+  using WeightF32OrInputT =
+      hwy::If<hwy::IsSame<Weight, float>(), float, EmbedderInputT>;
+  CompressedArray<WeightF32OrInputT, TConfig::kVocabSize * TConfig::kModelDim>
       embedder_input_embedding;
 
-  CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
+  using WeightF32OrBF16 =
+      hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
+  CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;
 
   // Must be last so that the other arrays remain aligned.
   CompressedLayerPointers<TConfig> c_layer_ptrs;
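The two using-aliases select element types at compile time with hwy::If: a config whose Weight is float keeps float tensors, so uncompressed f32 models now flow through the same CompressedWeights struct, while other configs keep the compact EmbedderInputT/bf16 types. A small illustration of the same selection; the config structs here are made up:

#include <type_traits>
#include "hwy/base.h"  // hwy::If, hwy::IsSame, hwy::bfloat16_t

struct ConfigF32 { using Weight = float; };
struct ConfigPacked { using Weight = unsigned char; };  // stand-in for a packed type

// Mirrors the aliases above: float configs store float norms, others bf16.
template <class TConfig>
using NormT = hwy::If<hwy::IsSame<typename TConfig::Weight, float>(),
                      float, hwy::bfloat16_t>;

static_assert(std::is_same_v<NormT<ConfigF32>, float>);
static_assert(std::is_same_v<NormT<ConfigPacked>, hwy::bfloat16_t>);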
@@ -233,10 +235,6 @@ struct CompressedWeights {
 // ----------------------------------------------------------------------------
 // Interface
 
-template <class TConfig>
-using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
-                         WeightsF<TConfig>>;
-
 // TODO: can we use TConfig::Weight instead of T?
 template <typename T, typename TConfig>
 struct AllocateWeights {
@@ -256,6 +254,17 @@ struct AllocateWeightsF {
   }
 };
 
+template <typename TConfig>
+struct AllocateCompressedWeights {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    using TWeights = CompressedWeights<TConfig>;
+    ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
+    TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
+    new (&weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
+    return weights_u8;
+  }
+};
+
 template <typename T, typename TConfig>
 struct ZeroInitWeights {
   void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
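AllocateCompressedWeights allocates raw bytes and placement-constructs only the one member that owns resources (c_layer_ptrs); DeleteLayersPtrs below symmetrically calls its destructor by name, consistent with the struct's "No ctor/dtor, allocated via AllocateAligned" comment. A minimal sketch of the idiom with simplified types; the real code uses aligned allocation and a ByteStorageT deleter:

#include <memory>
#include <new>
#include <vector>

struct LayerPtrs {
  LayerPtrs() : ptrs(4) {}
  std::vector<int*> ptrs;
};

// Shell is never constructed as a whole; it lives in raw storage.
struct Shell { float header[8]; LayerPtrs layer_ptrs; /* keep last */ };

int main() {
  auto bytes = std::make_unique<unsigned char[]>(sizeof(Shell));
  auto* shell = reinterpret_cast<Shell*>(bytes.get());
  new (&shell->layer_ptrs) LayerPtrs();  // construct just this member
  // ... use shell ...
  shell->layer_ptrs.~LayerPtrs();        // explicit destructor call
}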
@@ -277,6 +286,20 @@ struct ZeroInitWeightsF {
   }
 };
 
+template <typename TConfig>
+struct ZeroInitCompressedWeights {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    CompressedWeights<TConfig>& w =
+        *reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
+    hwy::ZeroBytes(&w.embedder_input_embedding,
+                   sizeof(w.embedder_input_embedding));
+    hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
+    for (int i = 0; i < TConfig::kLayers; ++i) {
+      hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
+    }
+  }
+};
+
 template <typename T, typename TConfig>
 struct CopyWeights {
   void operator()(Weights<T, TConfig>& dst,
@@ -295,12 +318,9 @@ void operator()(Weights<T, TConfig>& dst,
 template <class TConfig>
 struct DeleteLayersPtrs {
   void operator()(ByteStorageT& weights_u8) const {
-    auto* weights = reinterpret_cast<WeightsT<TConfig>*>(weights_u8.get());
-    if constexpr (kWeightsAreCompressed) {
+    auto* weights =
+        reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
     weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
-    } else {
-      weights->layer_ptrs.~LayerPointers<float, TConfig>();
-    }
   }
 };
 
@@ -330,13 +350,7 @@ class WeightsWrapper {
   Weights<T, TConfig>* weights_;
 };
 
-// For use by compress_weights.cc.
-ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
-                            Type weight_type, hwy::ThreadPool& pool,
-                            bool scale_for_compression);
-
-// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
-ByteStorageT LoadWeights(const Path& weights, Model model_type,
+ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
                          Type weight_type, hwy::ThreadPool& pool);
 
 void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
@@ -467,62 +481,62 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
     GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
   }
 
-template <typename T, typename TConfig, class Func>
-void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
+template <typename TConfig, class Func>
+void ForEachTensor1(Func& func, const CompressedWeights<TConfig>& weights1) {
   GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
     const size_t idx = static_cast<size_t>(layer_idx);
-    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
+    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
     GEMMA_CALL_ALL_LAYER_FUNC(1)
   }
 }
 
-template <typename T, typename TConfig, class Func>
-void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
+template <typename TConfig, class Func>
+void ForEachTensor1(Func& func, CompressedWeights<TConfig>& weights1) {
   GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
     const size_t idx = static_cast<size_t>(layer_idx);
-    LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
+    CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
     GEMMA_CALL_ALL_LAYER_FUNC(1)
   }
 }
 
-template <typename T, typename TConfig, class Func>
-void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
-                    Weights<T, TConfig>& weights2) {
+template <typename TConfig, class Func>
+void ForEachTensor2(Func& func, const CompressedWeights<TConfig>& weights1,
+                    CompressedWeights<TConfig>& weights2) {
   GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
     const size_t idx = static_cast<size_t>(layer_idx);
-    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
-    LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
+    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
+    CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
     GEMMA_CALL_ALL_LAYER_FUNC(2)
   }
 }
 
-template <typename T, typename TConfig, class Func>
-void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1,
-                    Weights<T, TConfig>& weights2,
-                    Weights<T, TConfig>& weights3,
-                    Weights<T, TConfig>& weights4) {
+template <typename TConfig, class Func>
+void ForEachTensor4(Func& func, const CompressedWeights<TConfig>& weights1,
+                    CompressedWeights<TConfig>& weights2,
+                    CompressedWeights<TConfig>& weights3,
+                    CompressedWeights<TConfig>& weights4) {
   GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
   GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
   char name_buf[16];
   for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
     auto type = TConfig::kLayerConfig[layer_idx];
     const size_t idx = static_cast<size_t>(layer_idx);
-    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
-    LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
-    LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
-    LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
+    const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
+    CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
+    CompressedLayer<TConfig>& layer3 = *weights3.GetLayer(idx);
+    CompressedLayer<TConfig>& layer4 = *weights4.GetLayer(idx);
     GEMMA_CALL_ALL_LAYER_FUNC(4)
   }
 }