Merge pull request #224 from szabadka:cleanup

PiperOrigin-RevId: 641922102
This commit is contained in:
Copybara-Service 2024-06-10 09:11:13 -07:00
commit 49d814b519
10 changed files with 311 additions and 294 deletions

View File

@ -28,7 +28,6 @@
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -51,12 +50,12 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
template <size_t kCols, size_t kRows>
void MatMulVJP(const std::array<float, kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
float* HWY_RESTRICT grad_w, // kRows * kCols,
float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -72,12 +71,12 @@ void MatMulVJP(const std::array<float, kRows * kCols>& weights,
template <size_t kHeads, size_t kCols, size_t kRows>
void MultiHeadMatMulVJP(
const std::array<float, kHeads * kRows * kCols>& weights,
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
const float* HWY_RESTRICT weights, // kHeads * kRows * kCols
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens,
std::array<float, kHeads * kRows * kCols>& grad_w,
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
float* HWY_RESTRICT grad_w, // kHeads * kRows * kCols
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -166,12 +165,12 @@ static HWY_NOINLINE void InputEmbeddingVJP(
}
}
template <typename TConfig>
void LayerVJP(const Layer<float, TConfig>& weights,
template <typename TConfig, template<typename> typename LayerT>
void LayerVJP(const LayerT<TConfig>& weights,
const ForwardLayer<float, TConfig>& forward,
const float* HWY_RESTRICT next_layer_grad,
size_t num_tokens,
Layer<float, TConfig>& grad,
LayerT<TConfig>& grad,
ForwardLayer<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
@ -184,8 +183,8 @@ void LayerVJP(const Layer<float, TConfig>& weights,
HWY_ASSERT(num_tokens <= kSeqLen);
MatMulVJP<kFFHiddenDim, kModelDim>(
weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -210,9 +209,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
}
MatMulVJP<kModelDim, kFFHiddenDim * 2>(
weights.gating_einsum_w,
weights.gating_einsum_w.data(),
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
num_tokens, grad.gating_einsum_w,
num_tokens, grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(),
forward.attention_out.data(),
@ -230,9 +229,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
weights.attn_vec_einsum_w, forward.att_out.data(),
weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(), num_tokens,
grad.attn_vec_einsum_w, backward.att_out.data(), pool);
grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -293,9 +292,9 @@ void LayerVJP(const Layer<float, TConfig>& weights,
}
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
backward.qkv.data(), num_tokens,
grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), num_tokens,
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(),
forward.input.data(),
backward.pre_att_rms_out.data(),
@ -345,11 +344,12 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
}
}
template <typename TConfig>
template <typename TConfig, template<typename> typename WeightsT,
template<typename> typename LayerT>
void CrossEntropyLossBackwardPass(const Prompt& prompt,
const Weights<float, TConfig>& weights,
const WeightsT<TConfig>& weights,
const ForwardPass<float, TConfig>& forward,
Weights<float, TConfig>& grad,
WeightsT<TConfig>& grad,
ForwardPass<float, TConfig>& backward,
hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize;
@ -379,9 +379,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
}
MatMulVJP<kModelDim, kVocabSize>(
weights.embedder_input_embedding, forward.final_norm_output.data(),
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
backward.logits.data(), num_tokens,
grad.embedder_input_embedding, backward.final_norm_output.data(),
grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
pool);
RMSNormVJP(weights.final_norm_scale.data(),
@ -398,8 +398,9 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
LayerVJP<TConfig, LayerT>(
*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,

View File

@ -42,13 +42,14 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
ByteStorageT& grad_u8,
ByteStorageT& backward_u8,
hwy::ThreadPool& pool) {
using TWeights = WeightsF<TConfig>;
using TWeights = CompressedWeights<TConfig>;
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
CrossEntropyLossBackwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt, weights, forward, grad, backward, pool);
}
void CrossEntropyLossBackwardPassT(Model model,

View File

@ -84,8 +84,8 @@ void TestMatMulVJP() {
};
hwy::ZeroBytes(&grad, sizeof(grad));
MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
grad, dx.data(), pool);
MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
grad.data(), dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
@ -130,7 +130,8 @@ void TestMultiHeadMatMulVJP() {
hwy::ZeroBytes(&grad, sizeof(grad));
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
pool);
TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
@ -235,7 +236,7 @@ void TestEndToEnd() {
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.clear();
CrossEntropyLossBackwardPass(
CrossEntropyLossBackwardPass<TestConfig, WeightsF, LayerF>(
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
pool);

View File

@ -41,11 +41,12 @@ float CrossEntropyLossForwardPass(const Prompt& prompt,
ByteStorageT& forward_u8,
hwy::ThreadPool& pool) {
const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return CrossEntropyLossForwardPass<TConfig, WeightsF, LayerF>(
prompt.tokens, prompt.context_size, weights, forward, pool);
return
CrossEntropyLossForwardPass<TConfig, CompressedWeights, CompressedLayer>(
prompt.tokens, prompt.context_size, weights, forward, pool);
}
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,

View File

@ -34,19 +34,17 @@
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
if (kWeightsAreCompressed) return;
hwy::ThreadPool pool(0);
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
Type weight_type = Type::kF32;
ByteStorageT grad =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT grad_m =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT grad_v =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
ByteStorageT forward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
ByteStorageT backward =
@ -88,10 +86,10 @@ TEST(OptimizeTest, GradientDescent) {
};
RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_m,
pool);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad_v,
pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad_m, pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad_v, pool);
printf("Initial weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights());
@ -109,8 +107,8 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
CallForModelAndWeight<ZeroInitWeightsF>(model_type, weight_type, grad,
pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad, pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
@ -139,7 +137,7 @@ TEST(OptimizeTest, GradientDescent) {
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights());
EXPECT_LT(steps, 200);
EXPECT_LT(steps, 300);
EXPECT_EQ(num_ok, kBatchSize);
}

View File

@ -31,10 +31,12 @@ class WeightInitializer {
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
template <size_t N>
void operator()(const char* name, std::array<float, N>& tensor) {
void operator()(const char* name, CompressedArray<float, N>& tensor) {
float* data = tensor.data();
for (size_t i = 0; i < N; ++i) {
tensor[i] = dist_(gen_);
data[i] = dist_(gen_);
}
tensor.set_scale(1.0f);
}
private:
std::normal_distribution<float> dist_;
@ -45,11 +47,12 @@ template <typename TConfig>
struct RandInitWeightsT {
void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) const {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
auto& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights);
ForEachTensor1<TConfig>(init, weights);
}
};
@ -62,18 +65,23 @@ class AdamUpdater {
norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
template <size_t kCapacity>
void operator()(const char* name, const std::array<float, kCapacity>& grad,
std::array<float, kCapacity>& weights,
std::array<float, kCapacity>& grad_m,
std::array<float, kCapacity>& grad_v) {
void operator()(const char* name,
const CompressedArray<float, kCapacity>& grad,
CompressedArray<float, kCapacity>& weights,
CompressedArray<float, kCapacity>& grad_m,
CompressedArray<float, kCapacity>& grad_v) {
const float* HWY_RESTRICT g = grad.data();
float* HWY_RESTRICT w = weights.data();
float* HWY_RESTRICT m = grad_m.data();
float* HWY_RESTRICT v = grad_v.data();
for (size_t i = 0; i < kCapacity; ++i) {
grad_m[i] *= beta1_;
grad_m[i] += cbeta1_ * grad[i];
grad_v[i] *= beta2_;
grad_v[i] += cbeta2_ * grad[i] * grad[i];
const float mhat = grad_m[i] * norm1_;
const float vhat = grad_v[i] * norm2_;
weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
m[i] *= beta1_;
m[i] += cbeta1_ * g[i];
v[i] *= beta2_;
v[i] += cbeta2_ * g[i] * g[i];
const float mhat = m[i] * norm1_;
const float vhat = v[i] * norm2_;
w[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
}
}
@ -94,13 +102,13 @@ struct AdamUpdateT {
float beta2, float epsilon, size_t t,
const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
using TWeights = CompressedWeights<TConfig>;
const auto& grad = *reinterpret_cast<const TWeights*>(grad_u8.get());
auto& weights = *reinterpret_cast<TWeights*>(weights_u8.get());
auto& grad_m = *reinterpret_cast<TWeights*>(grad_m_u8.get());
auto& grad_v = *reinterpret_cast<TWeights*>(grad_v_u8.get());
AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v);
ForEachTensor4<TConfig>(updater, grad, weights, grad_m, grad_v);
}
};
@ -109,17 +117,17 @@ struct AdamUpdateT {
void RandInitWeights(Model model_type, Type weight_type,
const ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen) {
CallForModelAndWeight<RandInitWeightsT>(model_type, weight_type, weights,
pool, gen);
HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, RandInitWeightsT>(model_type, weights, pool, gen);
}
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
CallForModelAndWeight<AdamUpdateT>(model_type, weight_type, grad, alpha,
beta1, beta2, epsilon, t, weights, grad_m,
grad_v, pool);
HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, AdamUpdateT>(model_type, grad, alpha, beta1, beta2,
epsilon, t, weights, grad_m, grad_v, pool);
}
} // namespace gcpp

View File

@ -41,10 +41,162 @@
#include "gemma/weights.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;
namespace {
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;
}
return scale;
}
#define READ_WEIGHTS(name) \
do { \
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
} while (0)
#define SCALE_WEIGHTS(name) \
do { \
if (ok && !kDryRunFread && scale_for_compression) { \
weights->scales[scale_pos++] = \
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
} \
} while (0)
template <typename TConfig>
struct LoadRawWeightsT {
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression) const {
PROFILER_ZONE("Startup.LoadWeights");
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
if constexpr (kDryRunFread) {
fprintf(stderr, "Dry-Run, not reading model-file.\n");
} else {
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
}
bool ok = true;
uint64_t total_size = 0;
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
if (layer == -1) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
} else {
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
size, name);
}
if constexpr (!kDryRunFread) {
ok &= 1 == fread(var, size, 1, fptr);
total_size += size;
}
};
do_fread(&(weights->embedder_input_embedding), -1,
"embedder_input_embedding",
sizeof(weights->embedder_input_embedding));
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin.gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
}
if (!ok) {
HWY_ABORT(
"Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB",
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
}
if (!kDryRunFread) {
HWY_ASSERT(0 == fclose(fptr));
if (scale_for_compression) {
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
}
return weights_u8;
}
};
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression) {
return CallForModelAndWeight<LoadRawWeightsT>(
model_type, weight_type, weights, pool, scale_for_compression);
}
struct Args : public ArgsBase<Args> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};

View File

@ -723,8 +723,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
}
template <class TConfig>
const WeightsT<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
return *reinterpret_cast<const WeightsT<TConfig>*>(weights_u8.get());
const CompressedWeights<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
return *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
}
template <class TConfig, size_t kBatchSize>
@ -741,7 +741,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info,
LayersOutputT* layers_output) {
const WeightsT<TConfig>& weights = GetWeights<TConfig>(weights_u8);
const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8);
@ -884,7 +884,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
tokenizer_(tokenizer_path),
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ = LoadWeights(weights, model_type, weight_type, pool);
weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
}
@ -895,8 +895,9 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
tokenizer_(std::move(tokenizer)),
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ =
CallForModelAndWeight<AllocateWeightsF>(model_type, weight_type, pool);
HWY_ASSERT(weight_type == Type::kF32);
weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
model_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
}

View File

@ -29,157 +29,6 @@
namespace gcpp {
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;
namespace {
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;
}
return scale;
}
#define READ_WEIGHTS(name) \
do { \
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
} while (0)
#define SCALE_WEIGHTS(name) \
do { \
if (ok && !kDryRunFread && scale_for_compression) { \
weights->scales[scale_pos++] = \
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
} \
} while (0)
template <typename TConfig>
struct LoadRawWeightsT {
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression) const {
PROFILER_ZONE("Startup.LoadWeights");
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
if constexpr (kDryRunFread) {
fprintf(stderr, "Dry-Run, not reading model-file.\n");
} else {
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
}
bool ok = true;
uint64_t total_size = 0;
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
if (layer == -1) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
} else {
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
size, name);
}
if constexpr (!kDryRunFread) {
ok &= 1 == fread(var, size, 1, fptr);
total_size += size;
}
};
do_fread(&(weights->embedder_input_embedding), -1,
"embedder_input_embedding",
sizeof(weights->embedder_input_embedding));
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin.gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
}
if (!ok) {
HWY_ABORT(
"Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB",
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
}
if (!kDryRunFread) {
HWY_ASSERT(0 == fclose(fptr));
if (scale_for_compression) {
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
}
return weights_u8;
}
};
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression) {
return CallForModelAndWeight<LoadRawWeightsT>(
model_type, weight_type, weights, pool, scale_for_compression);
}
namespace {
template <class TConfig>
struct LoadCompressedWeightsT {
@ -234,16 +83,6 @@ ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
weights, pool);
}
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
if constexpr (kWeightsAreCompressed) {
return LoadCompressedWeights(weights, model_type, weight_type, pool);
} else {
return LoadRawWeights(weights, model_type, weight_type, pool,
/*scale_for_compression=*/false);
}
}
namespace {
void LogVec(const char* name, const float* data, size_t len) {
hwy::Stats stats;
@ -257,7 +96,7 @@ void LogVec(const char* name, const float* data, size_t len) {
class WeightLogger {
public:
template <size_t N>
void operator()(const char* name, const std::array<float, N>& tensor) {
void operator()(const char* name, const CompressedArray<float, N>& tensor) {
LogVec(name, tensor.data(), N);
total_weights += N;
}
@ -268,9 +107,9 @@ template <typename TConfig>
struct LogWeightStatsT {
void operator()(const ByteStorageT& weights_u8) const {
const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights);
ForEachTensor1<TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights);
}
};
@ -278,7 +117,8 @@ struct LogWeightStatsT {
void LogWeightStats(gcpp::Model model_type, Type weight_type,
const ByteStorageT& weights) {
CallForModelAndWeight<LogWeightStatsT>(model_type, weight_type, weights);
HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, LogWeightStatsT>(model_type, weights);
}
} // namespace gcpp

View File

@ -25,9 +25,6 @@
namespace gcpp {
// Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true;
// ----------------------------------------------------------------------------
// Uncompressed
@ -213,11 +210,16 @@ struct CompressedLayerPointers {
template <class TConfig>
struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned.
using Weight = typename TConfig::Weight;
CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
using WeightF32OrInputT =
hwy::If<hwy::IsSame<Weight, float>(), float, EmbedderInputT>;
CompressedArray<WeightF32OrInputT, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
using WeightF32OrBF16 =
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>;
CompressedArray<WeightF32OrBF16, TConfig::kModelDim> final_norm_scale;
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
@ -233,10 +235,6 @@ struct CompressedWeights {
// ----------------------------------------------------------------------------
// Interface
template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
WeightsF<TConfig>>;
// TODO: can we use TConfig::Weight instead of T?
template <typename T, typename TConfig>
struct AllocateWeights {
@ -256,6 +254,17 @@ struct AllocateWeightsF {
}
};
template <typename TConfig>
struct AllocateCompressedWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
using TWeights = CompressedWeights<TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
return weights_u8;
}
};
template <typename T, typename TConfig>
struct ZeroInitWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
@ -277,6 +286,20 @@ struct ZeroInitWeightsF {
}
};
template <typename TConfig>
struct ZeroInitCompressedWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& w =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights.get());
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
}
};
template <typename T, typename TConfig>
struct CopyWeights {
void operator()(Weights<T, TConfig>& dst,
@ -295,12 +318,9 @@ void operator()(Weights<T, TConfig>& dst,
template <class TConfig>
struct DeleteLayersPtrs {
void operator()(ByteStorageT& weights_u8) const {
auto* weights = reinterpret_cast<WeightsT<TConfig>*>(weights_u8.get());
if constexpr (kWeightsAreCompressed) {
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
} else {
weights->layer_ptrs.~LayerPointers<float, TConfig>();
}
auto* weights =
reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}
};
@ -330,14 +350,8 @@ class WeightsWrapper {
Weights<T, TConfig>* weights_;
};
// For use by compress_weights.cc.
ByteStorageT LoadRawWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool,
bool scale_for_compression);
// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
ByteStorageT LoadWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
@ -467,62 +481,62 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, const Weights<T, TConfig>& weights1) {
template <typename TConfig, class Func>
void ForEachTensor1(Func& func, const CompressedWeights<TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor1(Func& func, Weights<T, TConfig>& weights1) {
template <typename TConfig, class Func>
void ForEachTensor1(Func& func, CompressedWeights<TConfig>& weights1) {
GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(1)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2) {
template <typename TConfig, class Func>
void ForEachTensor2(Func& func, const CompressedWeights<TConfig>& weights1,
CompressedWeights<TConfig>& weights2) {
GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(2)
}
}
template <typename T, typename TConfig, class Func>
void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2,
Weights<T, TConfig>& weights3,
Weights<T, TConfig>& weights4) {
template <typename TConfig, class Func>
void ForEachTensor4(Func& func, const CompressedWeights<TConfig>& weights1,
CompressedWeights<TConfig>& weights2,
CompressedWeights<TConfig>& weights3,
CompressedWeights<TConfig>& weights4) {
GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
const CompressedLayer<TConfig>& layer1 = *weights1.GetLayer(idx);
CompressedLayer<TConfig>& layer2 = *weights2.GetLayer(idx);
CompressedLayer<TConfig>& layer3 = *weights3.GetLayer(idx);
CompressedLayer<TConfig>& layer4 = *weights4.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(4)
}
}