Eliminated TConfig.

Changed CompressedLayer and CompressedWeights to be constructed with an instance of a LayerConfig and WeightsConfig respectively.
Added CompressedModel to remove ByteStorageT and get rid of most of the type casting, as well as allowing the default destructor to be used and work properly.
Adjusted WeightsWrapper and ForwardLayer etc to match.
The only remaining template arg is the weight type.
This enables all the instantiations to be deleted, apart from one per type.
It also enables (but not yet done) the config to be stored in the blob file instead of having to be specified separately.
Reduces the size of the gemma_lib and weights shared libraries by a factor of 4.3 and 3.2 respectively.

PiperOrigin-RevId: 686870060
This commit is contained in:
Ray Smith 2024-10-17 05:03:35 -07:00 committed by Copybara-Service
parent a4d6adbc43
commit 0d68555f87
68 changed files with 2810 additions and 2902 deletions

View File

@ -104,8 +104,6 @@ cc_test(
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":allocator", ":allocator",
":common",
":gemma_lib",
":ops", ":ops",
":test_util", ":test_util",
":threading", ":threading",
@ -183,7 +181,10 @@ cc_test(
cc_library( cc_library(
name = "common", name = "common",
srcs = ["gemma/common.cc"], srcs = [
"gemma/common.cc",
"gemma/configs.cc",
],
hdrs = [ hdrs = [
"gemma/common.h", "gemma/common.h",
"gemma/configs.h", "gemma/configs.h",
@ -195,12 +196,20 @@ cc_library(
], ],
) )
cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":common",
"@googletest//:gtest_main",
],
)
cc_library( cc_library(
name = "weights", name = "weights",
srcs = ["gemma/weights.cc"], srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"], hdrs = ["gemma/weights.h"],
deps = [ deps = [
":allocator",
":common", ":common",
"//compression:compress", "//compression:compress",
"//compression:io", "//compression:io",
@ -219,7 +228,6 @@ cc_library(
":common", ":common",
"//compression:io", "//compression:io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler", "@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor", "@com_google_sentencepiece//:sentencepiece_processor",
], ],
@ -239,30 +247,10 @@ cc_library(
name = "gemma_lib", name = "gemma_lib",
srcs = [ srcs = [
"gemma/gemma.cc", "gemma/gemma.cc",
"gemma/instantiations/27b_bf16.cc", "gemma/instantiations/bf16.cc",
"gemma/instantiations/27b_f32.cc", "gemma/instantiations/f32.cc",
"gemma/instantiations/27b_sfp.cc", "gemma/instantiations/nuq.cc",
"gemma/instantiations/2b_bf16.cc", "gemma/instantiations/sfp.cc",
"gemma/instantiations/2b_f32.cc",
"gemma/instantiations/2b_sfp.cc",
"gemma/instantiations/7b_bf16.cc",
"gemma/instantiations/7b_f32.cc",
"gemma/instantiations/7b_sfp.cc",
"gemma/instantiations/9b_bf16.cc",
"gemma/instantiations/9b_f32.cc",
"gemma/instantiations/9b_sfp.cc",
"gemma/instantiations/tiny_bf16.cc",
"gemma/instantiations/tiny_f32.cc",
"gemma/instantiations/tiny_sfp.cc",
"gemma/instantiations/gr2b_bf16.cc",
"gemma/instantiations/gr2b_f32.cc",
"gemma/instantiations/gr2b_sfp.cc",
"gemma/instantiations/gemma2_2b_bf16.cc",
"gemma/instantiations/gemma2_2b_f32.cc",
"gemma/instantiations/gemma2_2b_sfp.cc",
"gemma/instantiations/paligemma_224_bf16.cc",
"gemma/instantiations/paligemma_224_f32.cc",
"gemma/instantiations/paligemma_224_sfp.cc",
], ],
hdrs = [ hdrs = [
"gemma/activations.h", "gemma/activations.h",
@ -327,8 +315,6 @@ cc_library(
":threading", ":threading",
"//compression:io", "//compression:io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
], ],
) )
@ -367,7 +353,6 @@ cc_test(
":benchmark_helper", ":benchmark_helper",
":common", ":common",
":gemma_lib", ":gemma_lib",
":tokenizer",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"@highway//:hwy", "@highway//:hwy",
"@highway//:hwy_test_util", "@highway//:hwy_test_util",
@ -396,7 +381,6 @@ cc_binary(
name = "single_benchmark", name = "single_benchmark",
srcs = ["evals/benchmark.cc"], srcs = ["evals/benchmark.cc"],
deps = [ deps = [
":app",
":args", ":args",
":benchmark_helper", ":benchmark_helper",
":common", ":common",
@ -405,7 +389,6 @@ cc_binary(
"//compression:io", "//compression:io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:nanobenchmark", "@highway//:nanobenchmark",
"@highway//:thread_pool",
"@nlohmann_json//:json", "@nlohmann_json//:json",
], ],
) )
@ -429,13 +412,11 @@ cc_binary(
"evals/debug_prompt.cc", "evals/debug_prompt.cc",
], ],
deps = [ deps = [
":app",
":args", ":args",
":benchmark_helper", ":benchmark_helper",
":gemma_lib", ":gemma_lib",
"//compression:io", "//compression:io",
"@highway//:hwy", "@highway//:hwy",
"@highway//:thread_pool",
"@nlohmann_json//:json", "@nlohmann_json//:json",
], ],
) )
@ -444,7 +425,6 @@ cc_binary(
name = "gemma_mmlu", name = "gemma_mmlu",
srcs = ["evals/run_mmlu.cc"], srcs = ["evals/run_mmlu.cc"],
deps = [ deps = [
":app",
":args", ":args",
":benchmark_helper", ":benchmark_helper",
":gemma_lib", ":gemma_lib",
@ -488,7 +468,6 @@ cc_library(
deps = [ deps = [
":allocator", ":allocator",
":common", ":common",
":gemma_lib",
":ops", ":ops",
":prompt", ":prompt",
":weights", ":weights",
@ -508,7 +487,6 @@ cc_library(
"backprop/forward_scalar.h", "backprop/forward_scalar.h",
], ],
deps = [ deps = [
":allocator",
":common", ":common",
":prompt", ":prompt",
":weights", ":weights",
@ -525,7 +503,6 @@ cc_test(
"backprop/test_util.h", "backprop/test_util.h",
], ],
deps = [ deps = [
":allocator",
":backprop_scalar", ":backprop_scalar",
":common", ":common",
":prompt", ":prompt",
@ -599,6 +576,7 @@ cc_test(
":threading", ":threading",
":weights", ":weights",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//compression:sfp",
"@highway//:thread_pool", "@highway//:thread_pool",
], ],
) )

View File

@ -68,34 +68,15 @@ set(SOURCES
gemma/activations.h gemma/activations.h
gemma/common.cc gemma/common.cc
gemma/common.h gemma/common.h
gemma/configs.cc
gemma/configs.h gemma/configs.h
gemma/gemma-inl.h gemma/gemma-inl.h
gemma/gemma.cc gemma/gemma.cc
gemma/gemma.h gemma/gemma.h
gemma/instantiations/27b_bf16.cc gemma/instantiations/bf16.cc
gemma/instantiations/27b_f32.cc gemma/instantiations/f32.cc
gemma/instantiations/27b_sfp.cc gemma/instantiations/nuq.cc
gemma/instantiations/2b_bf16.cc gemma/instantiations/sfp.cc
gemma/instantiations/2b_f32.cc
gemma/instantiations/2b_sfp.cc
gemma/instantiations/7b_bf16.cc
gemma/instantiations/7b_f32.cc
gemma/instantiations/7b_sfp.cc
gemma/instantiations/9b_bf16.cc
gemma/instantiations/9b_f32.cc
gemma/instantiations/9b_sfp.cc
gemma/instantiations/gr2b_bf16.cc
gemma/instantiations/gr2b_f32.cc
gemma/instantiations/gr2b_sfp.cc
gemma/instantiations/tiny_bf16.cc
gemma/instantiations/tiny_f32.cc
gemma/instantiations/tiny_sfp.cc
gemma/instantiations/gemma2_2b_bf16.cc
gemma/instantiations/gemma2_2b_f32.cc
gemma/instantiations/gemma2_2b_sfp.cc
gemma/instantiations/paligemma_224_bf16.cc
gemma/instantiations/paligemma_224_f32.cc
gemma/instantiations/paligemma_224_sfp.cc
gemma/kv_cache.cc gemma/kv_cache.cc
gemma/kv_cache.h gemma/kv_cache.h
gemma/tokenizer.cc gemma/tokenizer.cc

View File

@ -18,32 +18,27 @@
#include <stddef.h> #include <stddef.h>
#include <array> #include <vector>
#include "compression/compress.h" // MatStorageT #include "compression/compress.h" // MatStorageT
#include "util/allocator.h" // ByteStorageT #include "gemma/configs.h" // ModelConfig
namespace gcpp { namespace gcpp {
template <typename T, typename TConfig> template <typename T>
struct ForwardLayer { struct ForwardLayer {
ForwardLayer() ForwardLayer(const LayerConfig& config, size_t seq_len)
: input("input", kSeqLen, kModelDim), : input("input", seq_len, config.model_dim),
pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim), pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim), qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
att("att", kSeqLen * kHeads, kSeqLen), att("att", seq_len * config.heads, seq_len),
att_out("att_out", kSeqLen * kHeads, kQKVDim), att_out("att_out", seq_len * config.heads, config.qkv_dim),
att_post1("att_post1", kSeqLen, kModelDim), att_post1("att_post1", seq_len, config.model_dim),
attention_out("attention_out", kSeqLen, kModelDim), attention_out("attention_out", seq_len, config.model_dim),
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim), bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2), ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {} ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
layer_config(config) {}
static constexpr size_t kSeqLen = TConfig::kSeqLen;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
MatStorageT<T> input; MatStorageT<T> input;
MatStorageT<T> pre_att_rms_out; MatStorageT<T> pre_att_rms_out;
@ -55,56 +50,30 @@ struct ForwardLayer {
MatStorageT<T> bf_pre_ffw_rms_out; MatStorageT<T> bf_pre_ffw_rms_out;
MatStorageT<T> ffw_hidden; MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated; MatStorageT<T> ffw_hidden_gated;
const LayerConfig& layer_config;
}; };
template <typename T, typename TConfig> template <typename T>
struct ForwardPass { struct ForwardPass {
ForwardPass() ForwardPass(const ModelConfig& config)
: final_layer_output("final_layer_output", kSeqLen, kModelDim), : final_layer_output("final_layer_output", config.seq_len,
final_norm_output("final_norm_output", kSeqLen, kModelDim), config.model_dim),
logits("logits", kSeqLen, kVocabSize), final_norm_output("final_norm_output", config.seq_len,
probs("probs", kSeqLen, kVocabSize) { config.model_dim),
} // prevents placement-new calling memset logits("logits", config.seq_len, config.vocab_size),
probs("probs", config.seq_len, config.vocab_size),
weights_config(config) {
for (const auto& layer_config : config.layer_configs) {
layers.emplace_back(layer_config, config.seq_len);
}
}
static constexpr size_t kSeqLen = TConfig::kSeqLen; std::vector<ForwardLayer<T>> layers;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kVocabSize = TConfig::kVocabSize;
static constexpr size_t kLayers = TConfig::kLayers;
std::array<ForwardLayer<T, TConfig>, kLayers> layers;
MatStorageT<T> final_layer_output; MatStorageT<T> final_layer_output;
MatStorageT<T> final_norm_output; MatStorageT<T> final_norm_output;
MatStorageT<T> logits; MatStorageT<T> logits;
MatStorageT<T> probs; MatStorageT<T> probs;
}; const ModelConfig& weights_config;
template <typename TConfig>
struct AllocateForwardPass {
ByteStorageT operator()() const {
ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<float, TConfig>>();
auto* c_weights =
reinterpret_cast<ForwardPass<float, TConfig>*>(c_weights_u8.get());
new (c_weights) ForwardPass<float, TConfig>();
return c_weights_u8;
}
};
// Owns activations and undoes the type erasure of AllocateAligned.
template<typename T, typename TConfig>
class ActivationsWrapper {
using WrappedT = ForwardPass<T, TConfig>;
public:
ActivationsWrapper()
: data_(AllocateSizeof<WrappedT>()),
activations_(*(new(data_.get()) WrappedT())) {}
const WrappedT& get() const { return activations_; }
WrappedT& get() { return activations_; }
private:
ByteStorageT data_;
WrappedT& activations_;
}; };
} // namespace gcpp } // namespace gcpp

View File

@ -28,6 +28,7 @@
#include "backprop/activations.h" #include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -53,45 +54,41 @@ namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
template <size_t kCols, size_t kRows> HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
const float* HWY_RESTRICT x, // num_tokens * kCols const float* HWY_RESTRICT x, // num_tokens * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens, size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // kRows * kCols, float* HWY_RESTRICT grad_w, // kRows * kCols,
float* HWY_RESTRICT grad_x, // num_tokens * kCols float* HWY_RESTRICT grad_x, // num_tokens * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t voffs = pos * kRows; const size_t voffs = pos * rows;
const size_t xoffs = pos * kCols; const size_t xoffs = pos * cols;
for (size_t j = 0; j < kRows; ++j) { for (size_t j = 0; j < rows; ++j) {
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols); MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs], MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
kCols);
} }
} }
} }
template <size_t kHeads, size_t kCols, size_t kRows> HWY_INLINE void MultiHeadMatMulVJP(
void MultiHeadMatMulVJP( const float* HWY_RESTRICT weights, // heads * kRows * kCols
const float* HWY_RESTRICT weights, // kHeads * kRows * kCols const float* HWY_RESTRICT x, // num_tokens * heads * kCols
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
const float* HWY_RESTRICT v, // num_tokens * kRows const float* HWY_RESTRICT v, // num_tokens * kRows
size_t num_tokens, size_t heads, size_t cols, size_t rows, size_t num_tokens,
float* HWY_RESTRICT grad_w, // kHeads * kRows * kCols float* HWY_RESTRICT grad_w, // heads * kRows * kCols
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0])); hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t j = 0; j < kRows; ++j) { for (size_t j = 0; j < rows; ++j) {
for (size_t h = 0; h < kHeads; ++h) { for (size_t h = 0; h < heads; ++h) {
MulByConstAndAdd(v[pos * kRows + j], MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
&x[pos * kHeads * kCols + h * kCols], &grad_w[h * rows * cols + j * cols], cols);
&grad_w[h * kRows * kCols + j * kCols], kCols); MulByConstAndAdd(v[pos * rows + j],
MulByConstAndAdd(v[pos * kRows + j], &weights[h * rows * cols + j * cols],
&weights[h * kRows * kCols + j * kCols], &grad_x[pos * heads * cols + h * cols], cols);
&grad_x[pos * kHeads * kCols + h * kCols], kCols);
} }
} }
} }
@ -168,39 +165,39 @@ static HWY_NOINLINE void InputEmbeddingVJP(
} }
} }
template <typename TConfig, typename LayerT> template <typename T>
void LayerVJP(const LayerT& weights, void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<float, TConfig>& forward, const ForwardLayer<float>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerT& grad, ForwardLayer<float, TConfig>& backward, LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
const RowVectorBatch<float>& inv_timescale, const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim; const LayerConfig& config = weights.layer_config;
static constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t model_dim = config.model_dim;
static constexpr size_t kHeads = TConfig::kHeads; const size_t qkv_dim = config.qkv_dim;
static constexpr size_t kSeqLen = TConfig::kSeqLen; const size_t heads = config.heads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; const size_t seq_len = forward.input.Rows();
static const float kQueryScale = const size_t ff_hidden_dim = config.ff_hidden_dim;
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim))); const float query_scale =
HWY_ASSERT(num_tokens <= kSeqLen); static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
HWY_ASSERT(num_tokens <= seq_len);
MatMulVJP<kFFHiddenDim, kModelDim>( MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad, next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(), grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2; const size_t hidden_offset = pos * ff_hidden_dim * 2;
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset; const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim; const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
const float* HWY_RESTRICT b_out_gated = const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.data() + pos * kFFHiddenDim; backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset; float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim; float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>; using DF = hn::ScalableTag<float>;
DF df; DF df;
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
const auto y = Load(df, f_out + i); const auto y = Load(df, f_out + i);
const auto x = Load(df, f_out_mul + i); const auto x = Load(df, f_out_mul + i);
const auto v = Load(df, b_out_gated + i); const auto v = Load(df, b_out_gated + i);
@ -209,101 +206,94 @@ void LayerVJP(const LayerT& weights,
} }
} }
MatMulVJP<kModelDim, kFFHiddenDim * 2>( MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
weights.gating_einsum_w.data(), backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
num_tokens, grad.gating_einsum_w.data(), num_tokens, grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), pool); backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(), RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
forward.attention_out.data(), backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
backward.bf_pre_ffw_rms_out.data(), grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
kModelDim, num_tokens, pool);
grad.pre_ffw_norm_scale.data(),
backward.attention_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * kModelDim, AddFrom(next_layer_grad + pos * model_dim,
backward.attention_out.data() + pos * kModelDim, kModelDim); backward.attention_out.data() + pos * model_dim, model_dim);
} }
backward.qkv.ZeroInit(); backward.qkv.ZeroInit();
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>( MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
weights.attn_vec_einsum_w.data(), forward.att_out.data(), backward.attention_out.data(), heads, qkv_dim, model_dim,
backward.attention_out.data(), num_tokens, num_tokens, grad.attn_vec_einsum_w.data(),
grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool); backward.att_out.data(), pool);
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset; const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
const float* HWY_RESTRICT b_att_out = const float* HWY_RESTRICT b_att_out =
backward.att_out.data() + (pos * kHeads + head) * kQKVDim; backward.att_out.data() + (pos * heads + head) * qkv_dim;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs; const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs; float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim); b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim); MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
} }
} }
} }
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset; const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1); SoftmaxVJP(f_head_att, b_head_att, pos + 1);
} }
} }
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim; const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen; const size_t aoffs = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs; const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs; const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs; float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim; const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs; const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs; float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim); MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim); MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
} }
} }
} }
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) { for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv = float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos); Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
} }
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q = float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(kQueryScale, b_q, kQKVDim); MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, kQKVDim, inv_timescale.Const(), -pos); Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
} }
} }
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>( MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
backward.qkv.data(), num_tokens,
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool); grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(), RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
forward.input.data(), backward.pre_att_rms_out.data(), model_dim, num_tokens,
backward.pre_att_rms_out.data(), grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
kModelDim, num_tokens,
grad.pre_attention_norm_scale.data(),
backward.input.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.data() + pos * kModelDim, AddFrom(backward.attention_out.data() + pos * model_dim,
backward.input.data() + pos * kModelDim, kModelDim); backward.input.data() + pos * model_dim, model_dim);
} }
} }
@ -342,20 +332,22 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
} }
} }
template <typename TConfig, typename WeightsT, typename LayerT> template <typename T>
void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights, void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
const ForwardPass<float, TConfig>& forward, const ModelWeightsPtrs<T>& weights,
WeightsT& grad, const ForwardPass<float>& forward,
ForwardPass<float, TConfig>& backward, ModelWeightsPtrs<T>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize; const ModelConfig& config = weights.weights_config;
static constexpr size_t kModelDim = TConfig::kModelDim; const size_t kVocabSize = config.vocab_size;
static constexpr size_t kLayers = TConfig::kLayers; const size_t model_dim = config.model_dim;
const float kEmbScaling = EmbeddingScaling<TConfig>(); const size_t kLayers = config.layer_configs.size();
static_assert(!TConfig::kAbsolutePE); const float kEmbScaling = EmbeddingScaling(model_dim);
static_assert(TConfig::kPostNorm == PostNormType::None); HWY_ASSERT(!config.absolute_pe);
static_assert(TConfig::kKVHeads == 1); HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(prompt.context_size > 0); HWY_DASSERT(prompt.context_size > 0);
HWY_DASSERT(prompt.context_size < prompt.tokens.size()); HWY_DASSERT(prompt.context_size < prompt.tokens.size());
@ -370,42 +362,38 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
kVocabSize); kVocabSize);
} }
if constexpr (TConfig::kFinalCap > 0.0f) { if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, kVocabSize); backward.logits.data() + pos * kVocabSize, kVocabSize);
} }
} }
MatMulVJP<kModelDim, kVocabSize>( MatMulVJP(weights.embedder_input_embedding.data(),
weights.embedder_input_embedding.data(), forward.final_norm_output.data(), forward.final_norm_output.data(), backward.logits.data(), model_dim,
backward.logits.data(), num_tokens, kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
grad.embedder_input_embedding.data(), backward.final_norm_output.data(), backward.final_norm_output.data(), pool);
RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
backward.final_norm_output.data(), model_dim, num_tokens,
grad.final_norm_scale.data(), backward.final_layer_output.data(),
pool); pool);
RMSNormVJP(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(),
kModelDim, num_tokens,
grad.final_norm_scale.data(),
backward.final_layer_output.data(), pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) { for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto type = TConfig::kLayerConfig[layer]; auto layer_config = config.layer_configs[layer];
// TODO(szabadka) Implement Griffin layer vjp. // TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(type == LayerAttentionType::kGemma); HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data() ? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data(); : backward.final_layer_output.data();
LayerVJP<TConfig, LayerT>(*weights.GetLayer(layer), forward.layers[layer], LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
next_layer_grad, num_tokens, num_tokens, *grad.GetLayer(layer), backward.layers[layer],
*grad.GetLayer(layer), backward.layers[layer],
inv_timescale, pool); inv_timescale, pool);
} }
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, backward.layers[0].input.data(), kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim); grad.embedder_input_embedding.data(), model_dim);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -38,44 +38,15 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
template <typename TConfig> void CrossEntropyLossBackwardPassT(const Prompt& prompt,
void CrossEntropyLossBackwardPass(const Prompt& prompt, const ModelWeightsPtrs<float>& weights,
const ByteStorageT& weights_u8, const ForwardPass<float>& forward,
const ByteStorageT& forward_u8, ModelWeightsPtrs<float>& grad,
ByteStorageT& grad_u8, ForwardPass<float>& backward,
ByteStorageT& backward_u8,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
using TWeights = CompressedWeights<TConfig>; CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get()); inv_timescale, pool);
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
using TAct = ForwardPass<float, TConfig>;
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
CompressedLayer<TConfig>>(
prompt, weights, forward, grad, backward, inv_timescale, pool);
}
void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
const ByteStorageT& forward,
ByteStorageT& grad, ByteStorageT& backward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
prompt, weights, forward, grad, backward, inv_timescale, pool);
break;
case Model::GEMMA_TINY:
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, grad, backward, inv_timescale, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
@ -87,14 +58,15 @@ namespace gcpp {
HWY_EXPORT(CrossEntropyLossBackwardPassT); HWY_EXPORT(CrossEntropyLossBackwardPassT);
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt, void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ByteStorageT& weights, const ModelWeightsPtrs<float>& weights,
const ByteStorageT& forward, const ForwardPass<float>& forward,
ByteStorageT& grad, ByteStorageT& backward, ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
model, prompt, weights, forward, grad, backward, inv_timescale, pool); prompt, weights, forward, grad, backward, inv_timescale, pool);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -16,17 +16,19 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/weights.h"
#include "gemma/common.h" #include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt, void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ByteStorageT& weights, const ModelWeightsPtrs<float>& weights,
const ByteStorageT& forward, const ForwardPass<float>& forward,
ByteStorageT& grad, ByteStorageT& backward, ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);

View File

@ -125,65 +125,64 @@ void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
} }
} }
template <typename T>
template<typename T>
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv, void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
size_t num_tokens, size_t kHeads, size_t kQKVDim, size_t num_tokens, size_t kHeads, size_t qkv_dim,
size_t kSeqLen) { size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * (kHeads + 2) * kQKVDim; const size_t offset = pos * (kHeads + 2) * qkv_dim;
memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0])); memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
} }
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim; const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen; const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
const T* q = qkv + qoffs; const T* q = qkv + qoffs;
const T* dout = doutput + aoffs; const T* dout = doutput + aoffs;
T* dq = dqkv + qoffs; T* dq = dqkv + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim; const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
const T* k = qkv + koffs; const T* k = qkv + koffs;
T* dk = dqkv + koffs; T* dk = dqkv + koffs;
MulByConstAndAddT(dout[pos2], k, dq, kQKVDim); MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
MulByConstAndAddT(dout[pos2], q, dk, kQKVDim); MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
} }
} }
} }
} }
template<typename T> template <typename T>
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
size_t kHeads, size_t kSeqLen) { size_t seq_len) {
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; size_t offset = pos * kHeads * seq_len + head * seq_len;
SoftmaxVJPT(y + offset, dy + offset, pos + 1); SoftmaxVJPT(y + offset, dy + offset, pos + 1);
memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
} }
} }
} }
template<typename T> template <typename T>
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
T* dqkv, T* dattention, size_t num_tokens, T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) { size_t qkv_dim, size_t seq_len) {
auto v_offset = [&](size_t pos) { auto v_offset = [&](size_t pos) {
return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim; return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
}; };
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0])); memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
} }
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < kHeads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim; const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
const T* att = &attention[aoffset]; const T* att = &attention[aoffset];
const T* dout = &doutput[offset]; const T* dout = &doutput[offset];
T* datt = &dattention[aoffset]; T* datt = &dattention[aoffset];
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim); datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim); MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
} }
} }
} }
@ -199,77 +198,76 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
} }
} }
template <typename T, typename TConfig> template <typename T>
void LayerVJP(const CompressedLayer<TConfig>& weights, void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<T, TConfig>& forward, const T* dy, const ForwardLayer<T>& forward, const T* dy,
CompressedLayer<TConfig>& grad, LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
ForwardLayer<T, TConfig>& backward, size_t num_tokens) { size_t num_tokens) {
static constexpr size_t kModelDim = TConfig::kModelDim; const LayerConfig& layer_config = weights.layer_config;
static constexpr size_t kSeqLen = TConfig::kSeqLen; const size_t model_dim = layer_config.model_dim;
static constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t seq_len = forward.input.Rows();
static constexpr size_t kHeads = TConfig::kHeads; const size_t qkv_dim = layer_config.qkv_dim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; const size_t kHeads = layer_config.heads;
static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim)); const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(), grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
kModelDim, kFFHiddenDim, num_tokens); kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(), GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens); backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), grad.gating_einsum_w.data(), backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim, backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
num_tokens); num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(), backward.bf_pre_ffw_rms_out.data(),
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
kModelDim, num_tokens); model_dim, num_tokens);
AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim); AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(), MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(), backward.attention_out.data(),
grad.attn_vec_einsum_w.data(), grad.attn_vec_einsum_w.data(), backward.att_out.data(),
backward.att_out.data(), kHeads, model_dim, qkv_dim, num_tokens);
kHeads, kModelDim, kQKVDim, num_tokens);
MixByAttentionVJP(forward.qkv.data(), forward.att.data(), MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
backward.att_out.data(), backward.qkv.data(), backward.att_out.data(), backward.qkv.data(),
backward.att.data(), num_tokens, kHeads, kQKVDim, backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
kSeqLen);
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
num_tokens, kHeads, kSeqLen); seq_len);
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(), MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen); backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim; T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
} }
for (int pos = 0; pos < num_tokens; ++pos) { for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim; T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
for (size_t h = 0; h <= kHeads; ++h) { for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, -pos); Rope(qkv + h * qkv_dim, qkv_dim, -pos);
} }
} }
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), grad.qkv_einsum_w.data(), backward.qkv.data(), grad.qkv_einsum_w.data(),
backward.pre_att_rms_out.data(), backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
(kHeads + 2) * kQKVDim, kModelDim, num_tokens); num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(), RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(), backward.pre_att_rms_out.data(),
grad.pre_attention_norm_scale.data(), grad.pre_attention_norm_scale.data(), backward.input.data(),
backward.input.data(), kModelDim, num_tokens); model_dim, num_tokens);
AddFromT(backward.attention_out.data(), backward.input.data(), AddFromT(backward.attention_out.data(), backward.input.data(),
num_tokens * kModelDim); num_tokens * model_dim);
} }
template <typename T> template <typename T>
@ -296,56 +294,54 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
} }
} }
template <typename T, typename TConfig> template <typename T>
void CrossEntropyLossBackwardPass(const Prompt& prompt, void CrossEntropyLossBackwardPass(const Prompt& prompt,
const CompressedWeights<TConfig>& weights, const ModelWeightsPtrs<T>& weights,
const ForwardPass<T, TConfig>& forward, const ForwardPass<T>& forward,
CompressedWeights<TConfig>& grad, ModelWeightsPtrs<T>& grad,
ForwardPass<T, TConfig>& backward) { ForwardPass<T>& backward) {
static constexpr size_t kModelDim = TConfig::kModelDim; const ModelConfig& config = weights.weights_config;
static constexpr size_t kVocabSize = TConfig::kVocabSize; const size_t model_dim = config.model_dim;
static constexpr size_t kLayers = TConfig::kLayers; const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens; const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
kVocabSize); vocab_size);
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
kVocabSize, num_tokens); num_tokens);
if constexpr (TConfig::kFinalCap > 0.0f) { if (config.final_cap > 0.0f) {
for (size_t i = 0; i < num_tokens; ++i) { for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize, SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
backward.logits.data() + i * kVocabSize, kVocabSize); backward.logits.data() + i * vocab_size, vocab_size);
} }
} }
MatMulVJPT(weights.embedder_input_embedding.data(), MatMulVJPT(
forward.final_norm_output.data(), weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
backward.logits.data(), backward.logits.data(), grad.embedder_input_embedding.data(),
grad.embedder_input_embedding.data(), backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
backward.final_norm_output.data(),
kVocabSize, kModelDim, num_tokens);
RMSNormVJPT(weights.final_norm_scale.data(), RMSNormVJPT(weights.final_norm_scale.data(),
forward.final_layer_output.data(), forward.final_layer_output.data(),
backward.final_norm_output.data(), backward.final_norm_output.data(), grad.final_norm_scale.data(),
grad.final_norm_scale.data(), backward.final_layer_output.data(), model_dim, num_tokens);
backward.final_layer_output.data(), kModelDim, num_tokens);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) { for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < kLayers T* next_layer_grad = layer + 1 < layers
? backward.layers[layer + 1].input.data() ? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data(); : backward.final_layer_output.data();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens); *grad.GetLayer(layer), backward.layers[layer], num_tokens);
} }
const T kEmbScaling = EmbeddingScaling(kModelDim); const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
tokens, kEmbScaling, backward.layers[0].input.data(), kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), kModelDim); grad.embedder_input_embedding.data(), model_dim);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -19,7 +19,6 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> // memcpy #include <string.h> // memcpy
#include <array>
#include <complex> #include <complex>
#include <limits> #include <limits>
#include <random> #include <random>
@ -384,44 +383,49 @@ TEST(BackPropTest, InputEmbeddingVJP) {
} }
} }
template <typename T> static ModelConfig TestConfig() {
struct TestConfig : ConfigBaseGemmaV2 { ModelConfig config;
using Weight = T; config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
static constexpr int kSeqLen = 18; "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
static constexpr int kVocabSize = 12; config.model_dim = 32;
static constexpr int kModelDim = 32; config.vocab_size = 12;
static constexpr int kHeads = 3; config.seq_len = 18;
static constexpr int kQKVDim = 12; LayerConfig layer_config = {
static constexpr int kFFHiddenDim = 48; .model_dim = config.model_dim,
static constexpr std::array<LayerAttentionType, 2> kLayerConfig = .ff_hidden_dim = 48,
FixedLayerConfig<2>(LayerAttentionType::kGemma); .heads = 3,
static constexpr int kLayers = kLayerConfig.size(); .kv_heads = 1,
static constexpr int kNumTensorScales = 4 * kLayers; .qkv_dim = 12,
static constexpr bool kAbsolutePE = false; };
static constexpr PostNormType kPostNorm = PostNormType::None; config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
static constexpr int kKVHeads = 1; config.query_scale = QueryScaleType::SqrtKeySize;
static constexpr int kGemmaLayers = kLayers; config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
}; // This is required for optimize_test to pass.
config.final_cap = 30.0f;
return config;
}
TEST(BackPropTest, LayerVJP) { TEST(BackPropTest, LayerVJP) {
std::mt19937 gen(42); std::mt19937 gen(42);
using T = double; using T = double;
using TC = std::complex<T>; using TC = std::complex<T>;
const size_t kOutputSize = TestConfig<T>::kSeqLen * TestConfig<T>::kModelDim; ModelConfig config = TestConfig();
CompressedLayer<TestConfig<T>> weights; const size_t kOutputSize = config.seq_len * config.model_dim;
CompressedLayer<TestConfig<T>> grad; LayerWeightsPtrs<T> weights(config.layer_configs[0]);
ForwardLayer<T, TestConfig<T>> forward; LayerWeightsPtrs<T> grad(config.layer_configs[0]);
ForwardLayer<T, TestConfig<T>> backward = {}; ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
CompressedLayer<TestConfig<TC>> c_weights; ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
ForwardLayer<TC, TestConfig<TC>> c_forward; LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
std::array<T, kOutputSize> y; ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1); MatStorageT<T> dy("dy", kOutputSize, 1);
std::array<TC, kOutputSize> c_y; MatStorageT<TC> c_y("c_y", kOutputSize, 1);
const size_t num_tokens = 3; const size_t num_tokens = 3;
weights.Allocate(); std::vector<MatStorage> layer_storage;
grad.Allocate(); weights.Allocate(layer_storage);
c_weights.Allocate(); grad.Allocate(layer_storage);
c_weights.Allocate(layer_storage);
backward.input.ZeroInit(); backward.input.ZeroInit();
for (size_t iter = 0; iter < 10; ++iter) { for (size_t iter = 0; iter < 10; ++iter) {
@ -432,7 +436,7 @@ TEST(BackPropTest, LayerVJP) {
Complexify(forward.input, c_forward.input); Complexify(forward.input, c_forward.input);
auto func = [&]() { auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data()); ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * TestConfig<T>::kModelDim); return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
}; };
grad.ZeroInit(/*layer_idx=*/0); grad.ZeroInit(/*layer_idx=*/0);
ApplyLayer(weights, forward, num_tokens, y.data()); ApplyLayer(weights, forward, num_tokens, y.data());
@ -447,12 +451,13 @@ TEST(BackPropTest, EndToEnd) {
std::mt19937 gen(42); std::mt19937 gen(42);
using T = double; using T = double;
using TC = std::complex<T>; using TC = std::complex<T>;
WeightsWrapper<TestConfig<T>> weights; ModelConfig config = TestConfig();
WeightsWrapper<TestConfig<T>> grad; WeightsWrapper<T> weights(config);
ForwardPass<T, TestConfig<T>> forward; WeightsWrapper<T> grad(config);
ForwardPass<T, TestConfig<T>> backward; ForwardPass<T> forward(config);
WeightsWrapper<TestConfig<TC>> c_weights; ForwardPass<T> backward(config);
ForwardPass<TC, TestConfig<TC>> c_forward; WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
ReverseSequenceSampler training_task({0, 0, 1, 1}); ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen); std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
@ -474,9 +479,9 @@ TEST(BackPropTest, EndToEnd) {
} }
} }
template <typename T, typename TConfig> template <typename T>
void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x, void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
CompressedLayer<TConfig>& out) { LayerWeightsPtrs<T>& out) {
MulByConstAndAddT(c, x.pre_attention_norm_scale, MulByConstAndAddT(c, x.pre_attention_norm_scale,
out.pre_attention_norm_scale); out.pre_attention_norm_scale);
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w); MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
@ -486,23 +491,23 @@ void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
MulByConstAndAddT(c, x.linear_w, out.linear_w); MulByConstAndAddT(c, x.linear_w, out.linear_w);
} }
template <typename T, typename TConfig> template <typename T>
void MulByConstAndAddT(T c, const CompressedWeights<TConfig>& x, void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
CompressedWeights<TConfig>& out) { ModelWeightsPtrs<T>& out) {
static constexpr size_t kLayers = TConfig::kLayers; const size_t layers = x.c_layers.size();
MulByConstAndAddT(c, x.embedder_input_embedding, MulByConstAndAddT(c, x.embedder_input_embedding,
out.embedder_input_embedding); out.embedder_input_embedding);
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale); MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) { for (size_t i = 0; i < layers; ++i) {
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i)); MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
} }
} }
// Evaluates forward pass on a batch. // Evaluates forward pass on a batch.
template <typename T, typename TConfig> template <typename T>
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch, T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
const WeightsWrapper<TConfig>& weights, const WeightsWrapper<T>& weights,
ForwardPass<T, TConfig>& forward) { ForwardPass<T>& forward) {
T loss = 0.0; T loss = 0.0;
for (const Prompt& prompt : batch) { for (const Prompt& prompt : batch) {
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward); loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
@ -514,12 +519,11 @@ T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
// Evaluates forward pass on a batch by applying gradient with the given // Evaluates forward pass on a batch by applying gradient with the given
// learning rate. Does not update weights, but uses the given tmp weights // learning rate. Does not update weights, but uses the given tmp weights
// instead. // instead.
template <typename T, typename TConfig> template <typename T>
T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch, T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
const WeightsWrapper<TConfig>& weights, const WeightsWrapper<T>& weights,
const WeightsWrapper<TConfig>& grad, const WeightsWrapper<T>& grad,
WeightsWrapper<TConfig>& tmp, WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
ForwardPass<T, TConfig>& forward) {
tmp.CopyFrom(weights); tmp.CopyFrom(weights);
const T scale = -learning_rate / batch.size(); const T scale = -learning_rate / batch.size();
MulByConstAndAddT(scale, grad.get(), tmp.get()); MulByConstAndAddT(scale, grad.get(), tmp.get());
@ -529,11 +533,9 @@ T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
// Uses line search in the negative gradient direction to update weights. We do // Uses line search in the negative gradient direction to update weights. We do
// this so that we can test that each step during the gradient descent can // this so that we can test that each step during the gradient descent can
// decrease the objective function value. // decrease the objective function value.
template <typename T, typename TConfig> template <typename T>
T FindOptimalUpdate(const WeightsWrapper<TConfig>& grad, T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
WeightsWrapper<TConfig>& weights, WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
WeightsWrapper<TConfig>& tmp,
ForwardPass<T, TConfig>& forward,
const std::vector<Prompt>& batch, T loss, const std::vector<Prompt>& batch, T loss,
T initial_learning_rate) { T initial_learning_rate) {
T lr0 = initial_learning_rate; T lr0 = initial_learning_rate;
@ -568,13 +570,14 @@ TEST(BackProptest, Convergence) {
std::mt19937 gen(42); std::mt19937 gen(42);
using T = float; using T = float;
using TC = std::complex<double>; using TC = std::complex<double>;
WeightsWrapper<TestConfig<T>> weights; ModelConfig config = TestConfig();
WeightsWrapper<TestConfig<T>> grad; WeightsWrapper<T> weights(config);
WeightsWrapper<TestConfig<T>> tmp; WeightsWrapper<T> grad(config);
ForwardPass<T, TestConfig<T>> forward; WeightsWrapper<T> tmp(config);
ForwardPass<T, TestConfig<T>> backward; ForwardPass<T> forward(config);
WeightsWrapper<TestConfig<TC>> c_weights; ForwardPass<T> backward(config);
ForwardPass<TC, TestConfig<TC>> c_forward; WeightsWrapper<TC> c_weights(config);
ForwardPass<TC> c_forward(config);
constexpr size_t kBatchSize = 5; constexpr size_t kBatchSize = 5;
ReverseSequenceSampler training_task({0, 0, 0, 1, 1}); ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
T learning_rate = 0.01; T learning_rate = 0.01;

View File

@ -19,7 +19,6 @@
#include <stddef.h> #include <stddef.h>
#include <array>
#include <complex> #include <complex>
#include <cstdlib> // std::abs #include <cstdlib> // std::abs
#include <random> #include <random>
@ -34,7 +33,6 @@
#include "backprop/test_util.h" #include "backprop/test_util.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -50,6 +48,7 @@
#include "backprop/forward-inl.h" #include "backprop/forward-inl.h"
#include "compression/compress.h" #include "compression/compress.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
#include "util/allocator.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
@ -85,7 +84,7 @@ void TestMatMulVJP() {
}; };
grad.ZeroInit(); grad.ZeroInit();
MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens, MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
grad.data(), dx.data(), pool); grad.data(), dx.data(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@ -130,9 +129,8 @@ void TestMultiHeadMatMulVJP() {
}; };
grad.ZeroInit(); grad.ZeroInit();
MultiHeadMatMulVJP<kHeads, kCols, kRows>( MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(), kRows, kTokens, grad.data(), dx.data(), pool);
pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
@ -186,63 +184,63 @@ void TestRMSNormVJP() {
} }
} }
template <typename T> static ModelConfig TestConfig() {
struct TestConfig : ConfigBaseGemmaV2 { ModelConfig config;
using Weight = T; config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
static constexpr int kSeqLen = 24; "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
static constexpr int kVocabSize = 16; config.model_dim = 32;
static constexpr int kModelDim = 32; config.vocab_size = 16;
static constexpr int kHeads = 3; config.seq_len = 24;
static constexpr int kQKVDim = 16; LayerConfig layer_config = {
static constexpr int kFFHiddenDim = 64; .model_dim = config.model_dim,
static constexpr std::array<LayerAttentionType, 2> kLayerConfig = .ff_hidden_dim = 64,
FixedLayerConfig<2>(LayerAttentionType::kGemma); .heads = 3,
static constexpr int kLayers = kLayerConfig.size(); .kv_heads = 1,
static constexpr int kNumTensorScales = 4 * kLayers; .qkv_dim = 16,
static constexpr bool kAbsolutePE = false; };
static constexpr PostNormType kPostNorm = PostNormType::None; config.layer_configs = {2, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
static constexpr int kKVHeads = 1; config.query_scale = QueryScaleType::SqrtKeySize;
static constexpr int kGemmaLayers = kLayers; config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
}; // This is required for optimize_test to pass.
config.att_cap = 50.0f;
config.final_cap = 30.0f;
return config;
}
void TestEndToEnd() { void TestEndToEnd() {
std::mt19937 gen(42); std::mt19937 gen(42);
hwy::ThreadPool pool(0); hwy::ThreadPool pool(0);
using WeightsF = CompressedWeights<TestConfig<float>>; ModelConfig config = TestConfig();
using LayerF = CompressedLayer<TestConfig<float>>; WeightsWrapper<float> weights(config);
WeightsWrapper<TestConfig<float>> weights; WeightsWrapper<float> grad(config);
WeightsWrapper<TestConfig<float>> grad; ForwardPass<float> forward0(config);
ActivationsWrapper<float, TestConfig<float>> forward0; ForwardPass<float> forward1(config);
ActivationsWrapper<float, TestConfig<float>> forward1; ForwardPass<float> backward(config);
ActivationsWrapper<float, TestConfig<float>> backward;
using TC = std::complex<double>; using TC = std::complex<double>;
WeightsWrapper<TestConfig<TC>> c_weights; WeightsWrapper<TC> c_weights(config);
ForwardPass<TC, TestConfig<TC>> c_forward; ForwardPass<TC> c_forward(config);
ReverseSequenceSampler training_task({0, 0, 1, 1}); ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen); std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
Activations::CreateInvTimescale<TestConfig<float>>(); config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
for (const Prompt& prompt : batch) { for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt); ReverseSequenceSampler::LogPrompt(prompt);
RandInit(weights.get(), 1.0f, gen); RandInit(weights.get(), 1.0f, gen);
float loss0 = CrossEntropyLossForwardPass( float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
prompt, weights.get(), forward0.get());
float loss1 = float loss1 = CrossEntropyLossForwardPass(
CrossEntropyLossForwardPass<TestConfig<float>, WeightsF, LayerF>( prompt.tokens, prompt.context_size, weights.get(), forward1,
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
inv_timescale, pool); inv_timescale, pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.ZeroInit(); grad.ZeroInit();
CrossEntropyLossBackwardPass<TestConfig<float>, WeightsF, LayerF>( CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
prompt, weights.get(), forward1.get(), grad.get(), backward.get(), backward, inv_timescale, pool);
inv_timescale, pool);
Complexify(weights.get(), c_weights.get()); Complexify(weights.get(), c_weights.get());
auto func = [&]() { auto func = [&]() {

View File

@ -26,6 +26,7 @@
#include "backprop/activations.h" #include "backprop/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -93,28 +94,28 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
return loss * scaling; return loss * scaling;
} }
template <typename TConfig, typename LayerT> template <typename T>
void ApplyForwardLayer(const LayerT& weights, void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<float, TConfig>& activations, ForwardLayer<float>& activations, size_t num_tokens,
size_t num_tokens, float* HWY_RESTRICT output, float* HWY_RESTRICT output,
const RowVectorBatch<float>& inv_timescale, const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim; const LayerConfig& config = weights.layer_config;
static constexpr size_t kSeqLen = TConfig::kSeqLen; const size_t model_dim = config.model_dim;
static constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t kSeqLen = activations.input.Rows();
static constexpr size_t kHeads = TConfig::kHeads; const size_t kQKVDim = config.qkv_dim;
static const float kQueryScale = const size_t kHeads = config.heads;
static const float query_scale =
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim))); static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen); HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.data(), ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
activations.input.data(), kModelDim, num_tokens, activations.input.data(), model_dim, num_tokens,
activations.pre_att_rms_out.data(), pool); activations.pre_att_rms_out.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<(kHeads + 2) * kQKVDim, kModelDim>( MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
weights.qkv_einsum_w, 0, activations.pre_att_rms_out.data() + pos * model_dim,
activations.pre_att_rms_out.data() + pos * kModelDim,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
} }
const size_t num_tasks = kHeads * num_tokens; const size_t num_tasks = kHeads * num_tokens;
@ -130,7 +131,7 @@ void ApplyForwardLayer(const LayerT& weights,
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.Const(), pos); Rope(q, kQKVDim, inv_timescale.Const(), pos);
MulByConst(kQueryScale, q, kQKVDim); MulByConst(query_scale, q, kQKVDim);
}); });
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
@ -174,28 +175,28 @@ void ApplyForwardLayer(const LayerT& weights,
activations.attention_out.ZeroInit(); activations.attention_out.ZeroInit();
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < kHeads; ++head) {
MatVec<kModelDim, kQKVDim>( MatVec(
weights.attn_vec_einsum_w, head * kModelDim * kQKVDim, weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
activations.att_post1.data() + pos * kModelDim, pool); activations.att_post1.data() + pos * model_dim, pool);
AddFrom(activations.att_post1.data() + pos * kModelDim, AddFrom(activations.att_post1.data() + pos * model_dim,
activations.attention_out.data() + pos * kModelDim, kModelDim); activations.attention_out.data() + pos * model_dim, model_dim);
} }
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.data() + pos * kModelDim, AddFrom(activations.input.data() + pos * model_dim,
activations.attention_out.data() + pos * kModelDim, kModelDim); activations.attention_out.data() + pos * model_dim, model_dim);
} }
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
activations.attention_out.data(), kModelDim, num_tokens, activations.attention_out.data(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.data(), pool); activations.bf_pre_ffw_rms_out.data(), pool);
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kFFHiddenDim * 2, kModelDim>( MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
weights.gating_einsum_w, 0, activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -215,77 +216,76 @@ void ApplyForwardLayer(const LayerT& weights,
} }
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kModelDim, kFFHiddenDim>( MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
weights.linear_w, 0,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
output + pos * kModelDim, pool); output + pos * model_dim, pool);
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * kModelDim, AddFrom(activations.attention_out.data() + pos * model_dim,
output + pos * kModelDim, kModelDim); output + pos * model_dim, model_dim);
} }
} }
template <typename TConfig, typename WeightsT, typename LayerT> template <typename T>
float CrossEntropyLossForwardPass(const std::vector<int>& prompt, float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size, const WeightsT& weights, size_t context_size,
ForwardPass<float, TConfig>& forward, const ModelWeightsPtrs<T>& weights,
ForwardPass<float>& forward,
const RowVectorBatch<float>& inv_timescale, const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
static constexpr size_t kVocabSize = TConfig::kVocabSize; const ModelConfig& config = weights.weights_config;
static constexpr size_t kModelDim = TConfig::kModelDim; const size_t vocab_size = config.vocab_size;
static constexpr size_t kLayers = TConfig::kLayers; const size_t model_dim = config.model_dim;
const float kEmbScaling = EmbeddingScaling<TConfig>(); const size_t layers = config.layer_configs.size();
static_assert(!TConfig::kAbsolutePE); const float emb_scaling = EmbeddingScaling(model_dim);
static_assert(TConfig::kPostNorm == PostNormType::None); HWY_ASSERT(!config.absolute_pe);
static_assert(TConfig::kKVHeads == 1); HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
HWY_DASSERT(context_size > 0); HWY_DASSERT(context_size > 0);
HWY_DASSERT(context_size < prompt.size()); HWY_DASSERT(context_size < prompt.size());
const size_t num_tokens = prompt.size() - 1; const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
forward.layers[0].input.data(), kModelDim, kVocabSize); forward.layers[0].input.data(), model_dim, vocab_size);
for (size_t layer = 0; layer < kLayers; ++layer) { for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
auto type = TConfig::kLayerConfig[layer]; auto type = config.layer_configs[layer].type;
// TODO(szabadka) Implement Griffin layer. // TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma); HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < kLayers ? float* HWY_RESTRICT output = layer + 1 < layers
forward.layers[layer + 1].input.data() : ? forward.layers[layer + 1].input.data()
forward.final_layer_output.data(); : forward.final_layer_output.data();
ApplyForwardLayer<TConfig, LayerT>(*weights.GetLayer(layer), ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
forward.layers[layer], num_tokens, num_tokens, output, inv_timescale, pool);
output, inv_timescale, pool);
} }
ApplyRMSNorm(weights.final_norm_scale.data(), ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(), forward.final_layer_output.data(), model_dim, num_tokens,
kModelDim, num_tokens, forward.final_norm_output.data(), pool); forward.final_norm_output.data(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec<kVocabSize, kModelDim>( MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
weights.embedder_input_embedding, 0, forward.final_norm_output.data() + pos * model_dim,
forward.final_norm_output.data() + pos * kModelDim, forward.logits.data() + pos * vocab_size, pool);
forward.logits.data() + pos * kVocabSize, pool);
} }
if constexpr (TConfig::kFinalCap > 0.0f) { if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(TConfig::kFinalCap, LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
forward.logits.data() + pos * kVocabSize, kVocabSize); vocab_size);
} }
} }
hwy::CopyBytes(forward.logits.data(), forward.probs.data(), hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * kVocabSize * sizeof(forward.logits.At(0))); num_tokens * vocab_size * sizeof(forward.logits.At(0)));
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
} }
return CrossEntropyLoss(forward.probs.data(), prompt, context_size, return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
kVocabSize, pool); vocab_size, pool);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -17,8 +17,9 @@
#include "backprop/activations.h" #include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
@ -36,38 +37,13 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
template <typename TConfig> float CrossEntropyLossForwardPassT(const Prompt& prompt,
float CrossEntropyLossForwardPass(const Prompt& prompt, const ModelWeightsPtrs<float>& weights,
const ByteStorageT& weights_u8, ForwardPass<float>& forward,
ByteStorageT& forward_u8,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const auto& weights = return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get()); weights, forward, inv_timescale, pool);
auto& forward =
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
return CrossEntropyLossForwardPass<TConfig, CompressedWeights<TConfig>,
CompressedLayer<TConfig>>(
prompt.tokens, prompt.context_size, weights, forward, inv_timescale,
pool);
}
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(
prompt, weights, forward, inv_timescale, pool);
case Model::GEMMA_TINY:
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
prompt, weights, forward, inv_timescale, pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
@ -79,13 +55,13 @@ namespace gcpp {
HWY_EXPORT(CrossEntropyLossForwardPassT); HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, float CrossEntropyLossForwardPass(const Prompt& prompt,
const ByteStorageT& weights, const ModelWeightsPtrs<float>& weights,
ByteStorageT& forward, ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
model, prompt, weights, forward, inv_timescale, pool); prompt, weights, forward, inv_timescale, pool);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -16,16 +16,17 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
#include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/weights.h"
#include "gemma/common.h" #include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, float CrossEntropyLossForwardPass(const Prompt& prompt,
const ByteStorageT& weights, const ModelWeightsPtrs<float>& weights,
ByteStorageT& forward, ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale, RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);

View File

@ -127,108 +127,107 @@ void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
} }
} }
template<typename T> template <typename T>
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
size_t kHeads, size_t kQKVDim, size_t kSeqLen) { size_t qkv_dim, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
const size_t qoffset = pos * (kHeads + 2) * kQKVDim; const size_t qoffset = pos * (heads + 2) * qkv_dim;
const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen; const size_t aoffset = pos * heads * seq_len + head * seq_len;
const T* q = qkv + qoffset + head * kQKVDim; const T* q = qkv + qoffset + head * qkv_dim;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
output[aoffset + pos2] = DotT(q, k, kQKVDim); output[aoffset + pos2] = DotT(q, k, qkv_dim);
} }
} }
} }
} }
template<typename T> template <typename T>
void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) { void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; size_t offset = pos * heads * seq_len + head * seq_len;
Softmax(x + offset, pos + 1); Softmax(x + offset, pos + 1);
memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
} }
} }
} }
template<typename T> template <typename T>
void MixByAttention(const T* qkv, const T* attention, T* output, void MixByAttention(const T* qkv, const T* attention, T* output,
size_t num_tokens, size_t kHeads, size_t kQKVDim, size_t num_tokens, size_t heads, size_t qkv_dim,
size_t kSeqLen) { size_t seq_len) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) { for (size_t head = 0; head < heads; ++head) {
const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen]; const T* att = &attention[pos * heads * seq_len + head * seq_len];
T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim]; T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
memset(out, 0, kQKVDim * sizeof(out[0])); memset(out, 0, qkv_dim * sizeof(out[0]));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) { for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const T* v = &qkv[v_offset]; const T* v = &qkv[v_offset];
MulByConstAndAddT(att[pos2], v, out, kQKVDim); MulByConstAndAddT(att[pos2], v, out, qkv_dim);
} }
} }
} }
} }
template <typename T, typename TConfig> template <typename T>
void ApplyLayer(const CompressedLayer<TConfig>& weights, void ApplyLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<T, TConfig>& activations, size_t num_tokens, ForwardLayer<T>& activations, size_t num_tokens, T* output) {
T* output) { const LayerConfig& layer_config = weights.layer_config;
static constexpr size_t kModelDim = TConfig::kModelDim; const size_t model_dim = layer_config.model_dim;
static constexpr size_t kSeqLen = TConfig::kSeqLen; const size_t seq_len = activations.input.Rows();
static constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t qkv_dim = layer_config.qkv_dim;
static constexpr size_t kHeads = TConfig::kHeads; const size_t heads = layer_config.heads;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim)); static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
activations.pre_att_rms_out.data(), kModelDim, num_tokens); activations.pre_att_rms_out.data(), model_dim, num_tokens);
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim, activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
for (size_t h = 0; h <= kHeads; ++h) { for (size_t h = 0; h <= heads; ++h) {
Rope(qkv + h * kQKVDim, kQKVDim, pos); Rope(qkv + h * qkv_dim, qkv_dim, pos);
} }
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); MulByConstT(query_scale, qkv, heads * qkv_dim);
} }
MaskedAttention(activations.qkv.data(), activations.att.data(), MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
num_tokens, kHeads, kQKVDim, kSeqLen); heads, qkv_dim, seq_len);
MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen); MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
MixByAttention(activations.qkv.data(), activations.att.data(), MixByAttention(activations.qkv.data(), activations.att.data(),
activations.att_out.data(), num_tokens, kHeads, kQKVDim, activations.att_out.data(), num_tokens, heads, qkv_dim,
kSeqLen); seq_len);
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
activations.attention_out.data(), kHeads, kModelDim, kQKVDim, activations.attention_out.data(), heads, model_dim, qkv_dim,
num_tokens); num_tokens);
AddFromT(activations.input.data(), activations.attention_out.data(), AddFromT(activations.input.data(), activations.attention_out.data(),
num_tokens * kModelDim); num_tokens * model_dim);
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens); activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim, activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
num_tokens); num_tokens);
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
kFFHiddenDim, num_tokens); ff_hidden_dim, num_tokens);
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
output, kModelDim, kFFHiddenDim, num_tokens); model_dim, ff_hidden_dim, num_tokens);
AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim); AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
} }
template<typename T> template<typename T>
@ -247,48 +246,47 @@ T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
return loss * scaling; return loss * scaling;
} }
template <typename T, typename TConfig> template <typename T>
T CrossEntropyLossForwardPass(const Prompt& prompt, T CrossEntropyLossForwardPass(const Prompt& prompt,
const CompressedWeights<TConfig>& weights, const ModelWeightsPtrs<T>& weights,
ForwardPass<T, TConfig>& forward) { ForwardPass<T>& forward) {
static constexpr size_t kModelDim = TConfig::kModelDim; const ModelConfig& config = weights.weights_config;
static constexpr size_t kVocabSize = TConfig::kVocabSize; const size_t model_dim = config.model_dim;
static constexpr size_t kLayers = TConfig::kLayers; const size_t vocab_size = config.vocab_size;
const size_t layers = config.layer_configs.size();
const std::vector<int> tokens = prompt.tokens; const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(kModelDim); const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbedding(weights.embedder_input_embedding.data(), tokens, InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
kEmbScaling, forward.layers[0].input.data(), kModelDim); forward.layers[0].input.data(), model_dim);
for (size_t layer = 0; layer < kLayers; ++layer) { for (size_t layer = 0; layer < layers; ++layer) {
T* output = layer + 1 < kLayers ? T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
forward.layers[layer + 1].input.data() : : forward.final_layer_output.data();
forward.final_layer_output.data();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output); output);
} }
RMSNormT(weights.final_norm_scale.data(), RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
forward.final_layer_output.data(), forward.final_norm_output.data(), model_dim, num_tokens);
forward.final_norm_output.data(), kModelDim, num_tokens);
MatMulT(weights.embedder_input_embedding.data(), MatMulT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), forward.final_norm_output.data(), forward.logits.data(), vocab_size,
forward.logits.data(), kVocabSize, kModelDim, num_tokens); model_dim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
if constexpr (TConfig::kFinalCap > 0.0f) { if (config.final_cap > 0.0f) {
Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
kVocabSize); vocab_size);
} }
} }
memcpy(forward.probs.data(), forward.logits.data(), memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * kVocabSize * sizeof(forward.logits.At(0))); num_tokens * vocab_size * sizeof(forward.logits.At(0)));
Softmax(forward.probs.data(), kVocabSize, num_tokens); Softmax(forward.probs.data(), vocab_size, num_tokens);
return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -16,6 +16,7 @@
#include <stddef.h> #include <stddef.h>
#include <algorithm> #include <algorithm>
#include <cstdio>
#include <random> #include <random>
#include <vector> #include <vector>
@ -26,8 +27,10 @@
#include "backprop/optimizer.h" #include "backprop/optimizer.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "backprop/sampler.h" #include "backprop/sampler.h"
#include "compression/shared.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/threading.h" #include "util/threading.h"
@ -45,20 +48,18 @@ TEST(OptimizeTest, GradientDescent) {
.training = ModelTraining::GEMMA_IT, .training = ModelTraining::GEMMA_IT,
.weight = Type::kF32, .weight = Type::kF32,
}; };
ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>( ModelConfig config = ConfigFromModel(info.model);
info.model, info.weight, pool); ModelWeightsStorage grad, grad_m, grad_v;
ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>( grad.Allocate(info.model, info.weight, pool);
info.model, info.weight, pool); grad_m.Allocate(info.model, info.weight, pool);
ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>( grad_v.Allocate(info.model, info.weight, pool);
info.model, info.weight, pool); grad_m.ZeroInit();
ByteStorageT forward = grad_v.ZeroInit();
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight); ForwardPass<float> forward(config), backward(config);
ByteStorageT backward = KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale = RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
Activations::CreateInvTimescale<ConfigGemmaTiny<float>>(); config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
Gemma gemma(GemmaTokenizer(), info, pools); Gemma gemma(GemmaTokenizer(), info, pools);
@ -92,14 +93,11 @@ TEST(OptimizeTest, GradientDescent) {
reply.begin() + context.size()); reply.begin() + context.size());
}; };
RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen); gemma.MutableWeights().RandInit(gen);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight, gemma.MutableWeights().AllocAndCopyWithTranspose(pool);
grad_m, pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
grad_v, pool);
printf("Initial weights:\n"); printf("Initial weights:\n");
LogWeightStats(info.model, info.weight, gemma.Weights()); gemma.MutableWeights().LogWeightStats();
constexpr size_t kBatchSize = 8; constexpr size_t kBatchSize = 8;
const float alpha = 0.001f; const float alpha = 0.001f;
@ -113,29 +111,29 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok; size_t num_ok;
for (; steps < 1000000; ++steps) { for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42); std::mt19937 sgen(42);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight, grad.ZeroInit();
grad, pool);
float total_loss = 0.0f; float total_loss = 0.0f;
num_ok = 0; num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) { for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen); Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass( total_loss += CrossEntropyLossForwardPass(
info.model, prompt, gemma.Weights(), forward, inv_timescale, pool); prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
grad, backward, inv_timescale, pool); CrossEntropyLossBackwardPass(
CallForModelAndWeight<ReshapeCompressedWeights>( prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
info.model, info.weight, gemma.MutableWeights(), pool); *grad.GetWeightsOfType<float>(), backward, inv_timescale, pool);
gemma.MutableWeights().CopyWithTranspose(pool);
num_ok += verify(prompt) ? 1 : 0; num_ok += verify(prompt) ? 1 : 0;
} }
total_loss /= kBatchSize; total_loss /= kBatchSize;
AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon, AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1,
steps + 1, gemma.Weights(), grad_m, grad_v, pool); gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize); steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) { if (steps % 100 == 0) {
printf("Batch gradient:\n"); printf("Batch gradient:\n");
LogWeightStats(info.model, info.weight, grad); grad.LogWeightStats();
} }
if (total_loss < 0.5f) { if (total_loss < 0.5f) {
break; break;
@ -143,7 +141,7 @@ TEST(OptimizeTest, GradientDescent) {
} }
printf("Num steps: %zu\n", steps); printf("Num steps: %zu\n", steps);
printf("Final weights:\n"); printf("Final weights:\n");
LogWeightStats(info.model, info.weight, gemma.Weights()); gemma.MutableWeights().LogWeightStats();
EXPECT_LT(steps, 300); EXPECT_LT(steps, 300);
EXPECT_EQ(num_ok, kBatchSize); EXPECT_EQ(num_ok, kBatchSize);
} }

View File

@ -16,7 +16,6 @@
#include "backprop/optimizer.h" #include "backprop/optimizer.h"
#include <cmath> #include <cmath>
#include <random>
#include "compression/compress.h" #include "compression/compress.h"
#include "gemma/common.h" #include "gemma/common.h"
@ -30,37 +29,6 @@ namespace gcpp {
namespace { namespace {
class WeightInitializer {
public:
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
float* data = tensors[0]->data<float>();
for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
data[i] = dist_(gen_);
}
tensors[0]->set_scale(1.0f);
}
private:
std::normal_distribution<float> dist_;
std::mt19937& gen_;
};
template <typename TConfig>
struct RandInitWeightsT {
void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) const {
auto& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
CompressedWeights<TConfig>::ForEachTensor({&weights},
ForEachType::kLoadNoToc, init);
}
};
class AdamUpdater { class AdamUpdater {
public: public:
explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon, explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
@ -97,42 +65,31 @@ class AdamUpdater {
float epsilon_; float epsilon_;
}; };
template <typename TConfig> void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
struct AdamUpdateT {
void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
float beta2, float epsilon, size_t t, float beta2, float epsilon, size_t t,
const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8, ModelWeightsPtrs<float>* weights,
const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const { ModelWeightsPtrs<float>* grad_m,
using TWeights = CompressedWeights<TConfig>; ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
auto& weights = *reinterpret_cast<TWeights*>(weights_u8.get());
auto& grad_m = *reinterpret_cast<TWeights*>(grad_m_u8.get());
auto& grad_v = *reinterpret_cast<TWeights*>(grad_v_u8.get());
AdamUpdater updater(alpha, beta1, beta2, epsilon, t); AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
TWeights::ForEachTensor( ModelWeightsPtrs<float>::ForEachTensor(
{&grad, &weights, &grad_m, &grad_v}, ForEachType::kLoadNoToc, {grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc,
[&updater](const char* name, hwy::Span<MatPtr*> tensors) { [&updater](const char* name, hwy::Span<MatPtr*> tensors) {
updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]); updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]);
}); });
} }
};
} // namespace } // namespace
void RandInitWeights(Model model_type, Type weight_type, void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
const ByteStorageT& weights, hwy::ThreadPool& pool, float beta1, float beta2, float epsilon, size_t t,
std::mt19937& gen) { const ModelWeightsStorage& weights,
const ModelWeightsStorage& grad_m,
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) {
HWY_ASSERT(weight_type == Type::kF32); HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, RandInitWeightsT>(model_type, weights, pool, gen); AdamUpdate(grad.GetWeightsOfType<float>(), alpha, beta1, beta2, epsilon, t,
} weights.GetWeightsOfType<float>(),
grad_m.GetWeightsOfType<float>(), grad_v.GetWeightsOfType<float>(),
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, pool);
float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, AdamUpdateT>(model_type, grad, alpha, beta1, beta2,
epsilon, t, weights, grad_m, grad_v, pool);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -16,22 +16,17 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <random>
#include "gemma/common.h" #include "gemma/common.h"
#include "util/allocator.h" #include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
void RandInitWeights(Model model_type, Type weight_type, void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
const ByteStorageT& weights, hwy::ThreadPool& pool, float beta1, float beta2, float epsilon, size_t t,
std::mt19937& gen); const ModelWeightsStorage& weights,
const ModelWeightsStorage& grad_m,
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
float alpha, float beta1, float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool);
} // namespace gcpp } // namespace gcpp

View File

@ -21,11 +21,12 @@
#include <cmath> #include <cmath>
#include <complex> #include <complex>
#include <random> #include <random>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "compression/compress.h" #include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
@ -39,8 +40,8 @@ void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
} }
// TODO: make a member of Layer<T>. // TODO: make a member of Layer<T>.
template <typename T, typename TConfig> template <typename T>
void RandInit(CompressedLayer<TConfig>& w, T stddev, std::mt19937& gen) { void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
RandInit(w.pre_attention_norm_scale, stddev, gen); RandInit(w.pre_attention_norm_scale, stddev, gen);
RandInit(w.attn_vec_einsum_w, stddev, gen); RandInit(w.attn_vec_einsum_w, stddev, gen);
RandInit(w.qkv_einsum_w, stddev, gen); RandInit(w.qkv_einsum_w, stddev, gen);
@ -49,9 +50,9 @@ void RandInit(CompressedLayer<TConfig>& w, T stddev, std::mt19937& gen) {
RandInit(w.linear_w, stddev, gen); RandInit(w.linear_w, stddev, gen);
} }
template <typename T, typename TConfig> template <typename T>
void RandInit(CompressedWeights<TConfig>& w, T stddev, std::mt19937& gen) { void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
static constexpr size_t kLayers = TConfig::kLayers; const size_t kLayers = w.c_layers.size();
RandInit(w.embedder_input_embedding, stddev, gen); RandInit(w.embedder_input_embedding, stddev, gen);
RandInit(w.final_norm_scale, stddev, gen); RandInit(w.final_norm_scale, stddev, gen);
for (size_t i = 0; i < kLayers; ++i) { for (size_t i = 0; i < kLayers; ++i) {
@ -66,9 +67,8 @@ void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
} }
} }
template <typename TConfig, typename UConfig> template <typename T, typename U>
void Complexify(const CompressedLayer<TConfig>& w, void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
CompressedLayer<UConfig>& c_w) {
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
@ -77,10 +77,9 @@ void Complexify(const CompressedLayer<TConfig>& w,
Complexify(w.linear_w, c_w.linear_w); Complexify(w.linear_w, c_w.linear_w);
} }
template <typename TConfig, typename UConfig> template <typename T, typename U>
void Complexify(const CompressedWeights<TConfig>& w, void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
CompressedWeights<UConfig>& c_w) { const size_t kLayers = w.c_layers.size();
static constexpr size_t kLayers = TConfig::kLayers;
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
Complexify(w.final_norm_scale, c_w.final_norm_scale); Complexify(w.final_norm_scale, c_w.final_norm_scale);
for (size_t i = 0; i < kLayers; ++i) { for (size_t i = 0; i < kLayers; ++i) {
@ -88,26 +87,27 @@ void Complexify(const CompressedWeights<TConfig>& w,
} }
} }
// Owns weights and provides access to TConfig. // Somewhat duplicates ModelWeightsStorage, but that has neither double nor
template <typename TConfig> // complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper { class WeightsWrapper {
public: public:
WeightsWrapper() explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), : pool_(0), weights_(config, pool_) {
data_(AllocateCompressedWeights<TConfig>()(pool_)), weights_.Allocate(data_, pool_);
weights_(reinterpret_cast<CompressedWeights<TConfig>*>(data_.get())) {} }
const CompressedWeights<TConfig>& get() const { return *weights_; } const ModelWeightsPtrs<T>& get() const { return weights_; }
CompressedWeights<TConfig>& get() { return *weights_; } ModelWeightsPtrs<T>& get() { return weights_; }
void ZeroInit() { weights_->ZeroInit(); } void ZeroInit() { weights_.ZeroInit(); }
void CopyFrom(const WeightsWrapper<TConfig>& other) { void CopyFrom(const WeightsWrapper<T>& other) {
get().CopyFrom(other.get()); weights_.CopyFrom(other.weights_);
} }
private: private:
hwy::ThreadPool pool_; hwy::ThreadPool pool_;
ByteStorageT data_; std::vector<MatStorage> data_;
CompressedWeights<TConfig>* weights_; ModelWeightsPtrs<T> weights_;
}; };
template <typename T, typename U> template <typename T, typename U>
@ -173,9 +173,9 @@ void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
} }
template <typename T, typename TConfig, typename UConfig, typename FUNC> template <typename T, typename U, typename FUNC>
void TestGradient(const CompressedLayer<TConfig>& grad, void TestGradient(const LayerWeightsPtrs<T>& grad,
CompressedLayer<UConfig>& c_weights, FUNC func, T max_err) { LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
TestGradient(grad.pre_attention_norm_scale, TestGradient(grad.pre_attention_norm_scale,
c_weights.pre_attention_norm_scale, c_weights.pre_attention_norm_scale,
func, max_err, max_err, __LINE__); func, max_err, max_err, __LINE__);
@ -191,15 +191,15 @@ void TestGradient(const CompressedLayer<TConfig>& grad,
func, max_err, max_err, __LINE__); func, max_err, max_err, __LINE__);
} }
template <typename T, typename TConfig, typename UConfig, typename FUNC> template <typename T, typename U, typename FUNC>
void TestGradient(const CompressedWeights<TConfig>& grad, void TestGradient(const ModelWeightsPtrs<T>& grad,
CompressedWeights<UConfig>& c_weights, FUNC func, T max_err) { ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
TestGradient(grad.embedder_input_embedding, TestGradient(grad.embedder_input_embedding,
c_weights.embedder_input_embedding, c_weights.embedder_input_embedding,
func, 2 * max_err, max_err, __LINE__); func, 2 * max_err, max_err, __LINE__);
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
func, max_err, max_err, __LINE__); func, max_err, max_err, __LINE__);
for (int i = 0; i < TConfig::kLayers; ++i) { for (size_t i = 0; i < grad.c_layers.size(); ++i) {
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err); TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
} }
} }

View File

@ -21,7 +21,6 @@
#include <atomic> #include <atomic>
#include <cstdio> #include <cstdio>
#include <memory> #include <memory>
#include <new>
#include <string> #include <string>
#include <vector> #include <vector>
@ -276,6 +275,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) { [pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
if (!pfile->Read(requests[i].offset, requests[i].size, if (!pfile->Read(requests[i].offset, requests[i].size,
requests[i].data)) { requests[i].data)) {
fprintf(stderr, "Failed to read blob %zu\n", i);
err.test_and_set(); err.test_and_set();
} }
}); });

View File

@ -102,8 +102,8 @@ class CompressedArray {
class MatPtr { class MatPtr {
public: public:
// Full constructor for dynamic sizing. // Full constructor for dynamic sizing.
MatPtr(const std::string& name, const std::string& type, size_t element_size, MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
size_t rows, size_t cols) size_t cols)
: name_(name), : name_(name),
type_(type), type_(type),
element_size_(element_size), element_size_(element_size),
@ -129,7 +129,7 @@ class MatPtr {
MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1, MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1,
const hwy::uint128_t& key2, const hwy::uint128_t& key3) const hwy::uint128_t& key2, const hwy::uint128_t& key3)
: name_(StringFromKey(key0)), : name_(StringFromKey(key0)),
type_(StringFromKey(key1)), type_(static_cast<Type>(key1.lo)),
element_size_(key2.hi), element_size_(key2.hi),
num_elements_(key2.lo), num_elements_(key2.lo),
rows_(key3.lo), rows_(key3.lo),
@ -138,7 +138,7 @@ class MatPtr {
// Adds the contents entry to the table of contents. // Adds the contents entry to the table of contents.
void AddToToc(std::vector<hwy::uint128_t>& toc) const { void AddToToc(std::vector<hwy::uint128_t>& toc) const {
toc.push_back(MakeKey(name_.c_str())); toc.push_back(MakeKey(name_.c_str()));
toc.push_back(MakeKey(type_.c_str())); toc.push_back({static_cast<uint64_t>(type_), 0});
toc.push_back({num_elements_, element_size_}); toc.push_back({num_elements_, element_size_});
toc.push_back({rows_, cols_}); toc.push_back({rows_, cols_});
} }
@ -167,7 +167,7 @@ class MatPtr {
void SetName(const std::string& name) { name_ = name; } void SetName(const std::string& name) { name_ = name; }
// Returns the type of the blob. // Returns the type of the blob.
const std::string& Type() const { return type_; } Type GetType() const { return type_; }
// Returns the size of each element in bytes. // Returns the size of each element in bytes.
size_t ElementSize() const { return element_size_; } size_t ElementSize() const { return element_size_; }
@ -219,8 +219,8 @@ class MatPtr {
protected: protected:
// Arbitrary name for the array of preferably <= 16 characters. // Arbitrary name for the array of preferably <= 16 characters.
std::string name_; std::string name_;
// Should be the result of TypeName<T> for CallUpcasted() to work. // Should be the result of TypeEnum<T> for CallUpcasted() to work.
std::string type_; Type type_;
// sizeof(T) // sizeof(T)
size_t element_size_ = 0; size_t element_size_ = 0;
// Number of elements in the array. // Number of elements in the array.
@ -247,7 +247,7 @@ class MatPtrT : public MatPtr {
// Full constructor for dynamic sizing. // Full constructor for dynamic sizing.
MatPtrT(const std::string& name, size_t rows, size_t cols) MatPtrT(const std::string& name, size_t rows, size_t cols)
: MatPtr(name, TypeName<MatT>(), sizeof(MatT), rows, cols) {} : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
// Copying allowed as the metadata is small. // Copying allowed as the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {} MatPtrT(const MatPtr& other) : MatPtr(other) {}
@ -330,17 +330,20 @@ class MatPtrT : public MatPtr {
template <class FuncT, typename... TArgs> template <class FuncT, typename... TArgs>
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
if (type_ == TypeName<float>()) { if (type_ == TypeEnum<float>()) {
return func(dynamic_cast<MatPtrT<float>*>(this), return func(dynamic_cast<MatPtrT<float>*>(this),
std::forward<TArgs>(args)...); std::forward<TArgs>(args)...);
} else if (type_ == TypeName<BF16>()) { } else if (type_ == TypeEnum<BF16>()) {
return func(dynamic_cast<MatPtrT<BF16>*>(this), return func(dynamic_cast<MatPtrT<BF16>*>(this),
std::forward<TArgs>(args)...); std::forward<TArgs>(args)...);
} else if (type_ == TypeName<SfpStream>()) { } else if (type_ == TypeEnum<SfpStream>()) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(this), return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
std::forward<TArgs>(args)...); std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<NuqStream>()) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
std::forward<TArgs>(args)...);
} else { } else {
HWY_ABORT("Type %s unknown.", type_.c_str()); HWY_ABORT("Type %d unknown.", type_);
} }
} }
@ -563,9 +566,10 @@ class CacheLoader {
} }
// Returns whether all tensors are successfully loaded from cache. // Returns whether all tensors are successfully loaded from cache.
bool ReadAll(hwy::ThreadPool& pool, std::vector<MatStorage>& model_memory) { BlobError ReadAll(hwy::ThreadPool& pool,
std::vector<MatStorage>& model_memory) {
// reader_ invalid or any Enqueue failed // reader_ invalid or any Enqueue failed
if (err_ != 0) return false; if (err_ != 0) return err_;
// Setup the model_memory. // Setup the model_memory.
for (int b = 0; b < model_toc_.size(); ++b) { for (int b = 0; b < model_toc_.size(); ++b) {
const std::string& file_key = file_keys_[b]; const std::string& file_key = file_keys_[b];
@ -574,12 +578,12 @@ class CacheLoader {
const MatPtr* toc_blob = file_toc_.Get(file_key); const MatPtr* toc_blob = file_toc_.Get(file_key);
if (toc_blob == nullptr) { if (toc_blob == nullptr) {
fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str()); fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str());
return false; return __LINE__;
} }
if (toc_blob->Rows() != blob->Rows() || if (toc_blob->Rows() != blob->Rows() ||
toc_blob->Cols() != blob->Cols()) { toc_blob->Cols() != blob->Cols()) {
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str()); fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
return false; return __LINE__;
} }
MatStorage toc_blob_array(*toc_blob); MatStorage toc_blob_array(*toc_blob);
model_memory.push_back(std::move(toc_blob_array)); model_memory.push_back(std::move(toc_blob_array));
@ -603,17 +607,10 @@ class CacheLoader {
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
blob.Name().c_str(), err_, blob.Rows(), blob.Cols(), blob.Name().c_str(), err_, blob.Rows(), blob.Cols(),
blob.ElementSize()); blob.ElementSize());
return false; return err_;
} }
} }
return reader_.ReadAll(pool);
err_ = reader_.ReadAll(pool);
if (err_ != 0) {
fprintf(stderr, "Failed to read all tensors (error %d)\n", err_);
return false;
}
return true;
} }
private: private:

View File

@ -24,6 +24,7 @@
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h // After highway.h
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "gemma/configs.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE #define GEMMA_COMPRESS_WEIGHTS_ONCE
@ -150,29 +151,22 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
template <class Configs> template <typename T>
void CompressWeights(const Path& weights_path, void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type, const Path& compressed_weights_path, Model model_type,
Type weight_type, hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
if (!weights_path.Exists()) { if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.", HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str()); weights_path.path.c_str());
} }
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
compressed_weights_path.path.c_str()); compressed_weights_path.path.c_str());
ModelConfig config = ConfigFromModel(model_type);
using CConfig = typename Configs::c; std::vector<MatStorage> model_storage;
using UCConfig = typename Configs::uc; ModelWeightsPtrs<T> c_weights(config, pool);
// Allocate compressed weights. c_weights.Allocate(model_storage, pool);
using CWeights = CompressedWeights<CConfig>; ModelWeightsPtrs<float> uc_weights(config, pool);
ByteStorageT c_weights_u8 = AllocateCompressedWeights<CConfig>()(pool); uc_weights.Allocate(model_storage, pool);
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
// Allocate uncompressed weights.
using UCWeights = CompressedWeights<UCConfig>;
ByteStorageT uc_weights_u8 = AllocateCompressedWeights<UCConfig>()(pool);
UCWeights* uc_weights = reinterpret_cast<UCWeights*>(uc_weights_u8.get());
// Get uncompressed weights, compress, and store. // Get uncompressed weights, compress, and store.
FILE* fptr = fopen(weights_path.path.c_str(), "rb"); FILE* fptr = fopen(weights_path.path.c_str(), "rb");
if (fptr == nullptr) { if (fptr == nullptr) {
@ -181,22 +175,22 @@ void CompressWeights(const Path& weights_path,
} }
bool ok = true; bool ok = true;
uint64_t total_size = 0; uint64_t total_size = 0;
CompressedWeights<UCConfig>::ForEachTensor( ModelWeightsPtrs<float>::ForEachTensor(
{uc_weights}, ForEachType::kLoadNoToc, {&uc_weights}, ForEachType::kLoadNoToc,
[&](const char* name, hwy::Span<MatPtr*> tensors) { [&](const char* name, hwy::Span<MatPtr*> tensors) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", fprintf(stderr, "Loading Parameters (size %zu): %s\n",
tensors[0]->SizeBytes(), name); tensors[0]->SizeBytes(), name);
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
total_size += tensors[0]->SizeBytes(); total_size += tensors[0]->SizeBytes();
}); });
const bool scale_for_compression = UCConfig::kNumTensorScales > 0; const bool scale_for_compression = config.num_tensor_scales > 0;
std::vector<float> scales; std::vector<float> scales;
if (scale_for_compression) { if (scale_for_compression) {
uc_weights->GetOrApplyScales(scales); uc_weights.GetOrApplyScales(scales);
} }
Compressor compressor(pool); Compressor compressor(pool);
CompressedWeights<CConfig>::ForEachTensor( ModelWeightsPtrs<T>::ForEachTensor(
{reinterpret_cast<CompressedWeights<CConfig>*>(uc_weights), c_weights}, {reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
ForEachType::kLoadNoToc, ForEachType::kLoadNoToc,
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) { [&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[1]->CallUpcasted( tensors[1]->CallUpcasted(
@ -221,9 +215,26 @@ void Run(Args& args) {
HWY_ABORT("PaliGemma is not supported in compress_weights."); HWY_ABORT("PaliGemma is not supported in compress_weights.");
} }
const Type weight_type = args.WeightType(); const Type weight_type = args.WeightType();
GEMMA_EXPORT_AND_DISPATCH( switch (weight_type) {
model_type, weight_type, CompressWeights, case Type::kF32:
(args.weights, args.compressed_weights, model_type, weight_type, pool)); HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kBF16:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kSFP:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kNUQ:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
(args.weights, args.compressed_weights, model_type, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
}
} }
} // namespace gcpp } // namespace gcpp

View File

@ -32,11 +32,6 @@ namespace gcpp {
using BF16 = hwy::bfloat16_t; using BF16 = hwy::bfloat16_t;
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 // Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format. // inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32. // It supports seeking at a granularity of 1 and decoding to bf16/f32.
@ -179,29 +174,67 @@ struct NuqStream {
}; };
#pragma pack(pop) #pragma pack(pop)
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
}
template <typename Packed>
constexpr bool IsSfpStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>();
}
template <typename Packed>
constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// ModelWeightsPtrs. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64", "u128"};
// Returns a Type enum for the type of the template parameter.
template <typename PackedT> template <typename PackedT>
const char* TypeName() { Type TypeEnum() {
using Packed = hwy::RemoveCvRef<PackedT>; using Packed = hwy::RemoveCvRef<PackedT>;
if constexpr (hwy::IsSame<Packed, float>()) { if constexpr (hwy::IsSame<Packed, float>()) {
return "f32"; return Type::kF32;
} else if constexpr (hwy::IsSame<Packed, BF16>()) { } else if constexpr (hwy::IsSame<Packed, BF16>()) {
return "b16"; return Type::kBF16;
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) { } else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
return "sfp"; return Type::kSFP;
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) { } else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
return "nuq"; return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) { } else if constexpr (hwy::IsSame<Packed, double>()) {
return "f64"; return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) { } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return "c64"; return Type::kC64;
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) { } else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
return "u128"; return Type::kU128;
} else { } else {
HWY_DASSERT(false); HWY_DASSERT(false);
return "unknown"; return Type::kUnknown;
} }
} }
// Returns a string name for the type of the template parameter.
template <typename PackedT>
const char* TypeName() {
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
}
template <typename Packed> template <typename Packed>
constexpr bool IsCompressed() { constexpr bool IsCompressed() {
return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>(); return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();

View File

@ -128,8 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens); size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos, std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens); prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create( KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size); env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy( float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy; total_entropy += entropy;

View File

@ -69,8 +69,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
model_ = AllocateGemma(mutable_loader, pools_); model_ = AllocateGemma(mutable_loader, pools_);
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1); kv_caches_.resize(1);
kv_caches_[0] = kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
KVCache::Create(model_->Info().model, inference.prefill_tbatch_size); inference.prefill_tbatch_size);
} }
InitGenerator(inference, gen_); InitGenerator(inference, gen_);
runtime_config_ = { runtime_config_ = {
@ -163,7 +163,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
} }
for (size_t i = 1; i < num_queries; ++i) { for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) { if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(model_->Info().model, kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
runtime_config_.prefill_tbatch_size); runtime_config_.prefill_tbatch_size);
} }
} }

View File

@ -103,8 +103,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const StreamFunc stream_token = [](int /*token*/, float) { return true; }; const StreamFunc stream_token = [](int /*token*/, float) { return true; };
// TWeight is unused, but we have to pass it to Config*. // TWeight is unused, but we have to pass it to Config*.
const int vocab_size = const int vocab_size = gemma.GetModelConfig().vocab_size;
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
float cross_entropy = std::log(vocab_size); // first token float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1; size_t pos = 1;

View File

@ -24,7 +24,6 @@
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. // Placeholder for internal header, do not modify.
#include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs #include "util/app.h" // LoaderArgs
@ -58,7 +57,8 @@ int main(int argc, char** argv) {
gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin); gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools); gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache = gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size); gcpp::KVCache::Create(model.GetModelConfig(),
inference.prefill_tbatch_size);
size_t generated = 0; size_t generated = 0;
// Initialize random number generator // Initialize random number generator

View File

@ -21,6 +21,7 @@
#include <cmath> #include <cmath>
#include "compression/shared.h" // BF16 #include "compression/shared.h" // BF16
#include "gemma/configs.h"
#include "ops/matmul.h" // MatMulEnv #include "ops/matmul.h" // MatMulEnv
#include "util/allocator.h" // RowVectorBatch #include "util/allocator.h" // RowVectorBatch
#include "util/threading.h" #include "util/threading.h"
@ -30,6 +31,12 @@
namespace gcpp { namespace gcpp {
struct Activations { struct Activations {
explicit Activations(const ModelConfig& config)
: weights_config(config),
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()) {}
RowVectorBatch<float> x; // input RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA. RowVectorBatch<float> q; // query, also KV if MHA.
RowVectorBatch<float> logits; RowVectorBatch<float> logits;
@ -58,23 +65,24 @@ struct Activations {
MatMulEnv env; MatMulEnv env;
PostQKType post_qk = PostQKType::Rope;
// And the config.
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0;
// Multi-Head Attention? // Multi-Head Attention?
template <class TConfig> bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }
static constexpr bool IsMHA() {
return TConfig::kHeads == TConfig::kKVHeads;
}
// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
template <class TConfig> size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }
static constexpr size_t QStride() {
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
}
template <class TConfig> static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
static RowVectorBatch<float> CreateInvTimescale() { PostQKType post_qk) {
constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t rope_dim =
const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim; post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(1, rope_dim / 2); RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) { for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const float freq_exponents = const float freq_exponents =
@ -86,40 +94,38 @@ struct Activations {
return inv_timescale; return inv_timescale;
} }
template <class TConfig>
void Allocate(size_t batch_size, PerClusterPools& pools) { void Allocate(size_t batch_size, PerClusterPools& pools) {
constexpr size_t kModelDim = TConfig::kModelDim; post_qk = layer_config.post_qk;
constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t model_dim = weights_config.model_dim;
constexpr size_t kHeads = TConfig::kHeads; const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; const size_t vocab_size = weights_config.vocab_size;
constexpr size_t kVocabSize = TConfig::kVocabSize;
constexpr size_t kSeqLen = TConfig::kSeqLen;
constexpr size_t kGriffinLayers = TConfig::kGriffinLayers;
x = RowVectorBatch<float>(batch_size, kModelDim); x = RowVectorBatch<float>(batch_size, model_dim);
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>()); q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
if constexpr (kVocabSize > 0) { if (vocab_size > 0) {
logits = RowVectorBatch<float>(batch_size, kVocabSize); logits = RowVectorBatch<float>(batch_size, vocab_size);
} }
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim); pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen); att = RowVectorBatch<float>(batch_size,
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim); layer_config.heads * weights_config.seq_len);
att_sums = RowVectorBatch<float>(batch_size, kModelDim); att_out = RowVectorBatch<float>(batch_size,
layer_config.heads * layer_config.qkv_dim);
att_sums = RowVectorBatch<float>(batch_size, model_dim);
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, kModelDim); bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim); C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim); C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
ffw_out = RowVectorBatch<float>(batch_size, kModelDim); ffw_out = RowVectorBatch<float>(batch_size, model_dim);
if constexpr (kGriffinLayers > 0) { if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
griffin_x = RowVectorBatch<float>(batch_size, kModelDim); griffin_x = RowVectorBatch<float>(batch_size, model_dim);
griffin_y = RowVectorBatch<float>(batch_size, kModelDim); griffin_y = RowVectorBatch<float>(batch_size, model_dim);
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim); griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim); griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
} }
inv_timescale = CreateInvTimescale<TConfig>(); inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);
env = MatMulEnv(pools); env = MatMulEnv(pools);
} }

View File

@ -15,6 +15,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include <math.h> // sqrtf
#include <stddef.h> #include <stddef.h>
#include <string.h> #include <string.h>
@ -23,6 +24,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "compression/shared.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -101,8 +103,6 @@ const char* ModelString(Model model, ModelTraining training) {
static_cast<int>(training)); static_cast<int>(training));
} }
constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};
const char* StringFromType(Type type) { const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)]; return kTypeStrings[static_cast<size_t>(type)];
} }
@ -141,4 +141,19 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n"; prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
} }
} }
float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
sqrtf(static_cast<float>(model_dim))));
}
float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}
} // namespace gcpp } // namespace gcpp

View File

@ -16,37 +16,15 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include <math.h> // sqrtf
#include <stddef.h> #include <stddef.h>
#include <string> #include <string>
#include "compression/compress.h"
#include "gemma/configs.h" // IWYU pragma: export #include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo #include "hwy/base.h" // ConvertScalarTo
namespace gcpp { namespace gcpp {
// Model variants: see configs.h for details. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Model {
GEMMA_2B,
GEMMA_7B,
GEMMA2_9B,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
};
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Tensor types for loading weights. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Type { kF32, kBF16, kSFP };
// TODO(janwas): merge with functions below. // TODO(janwas): merge with functions below.
struct ModelInfo { struct ModelInfo {
Model model; Model model;
@ -66,198 +44,12 @@ const char* StringFromType(Type type);
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
// but can also be called directly when FuncT does not actually use TWeight.
//
// Note that a T prefix indicates a concrete type template argument, whereas a
// T suffix indicates the argument is itself a template.
//
// `FuncT` must be a functor because function templates cannot be passed as a
// template template argument, and we prefer to avoid the overhead of
// std::function.
template <typename TWeight, template <typename TConfig> class FuncT,
typename... TArgs>
decltype(auto) CallForModel(Model model, TArgs&&... args) {
switch (model) {
case Model::GEMMA_TINY:
return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_2B:
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_7B:
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_9B:
return FuncT<ConfigGemma2_9B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_27B:
return FuncT<ConfigGemma2_27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B:
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::PALIGEMMA_224:
return FuncT<ConfigPaliGemma_224<TWeight>>()(
std::forward<TArgs>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
// Returns the return value of FuncT<TConfig>().operator()(args),
// where `TConfig` is selected based on `model` and `weight`.
// This makes it easy to extend `Model` or `Type` without updating callers.
//
// Usage example: LoadWeights is type-erased so that it can be called from other
// .cc files. It uses this function to call the appropriate instantiation of a
// template functor LoadCompressedWeightsT<TConfig>.
template <template <typename TConfig> class FuncT, typename... TArgs>
decltype(auto) CallForModelAndWeight(Model model, Type weight,
TArgs&&... args) {
switch (weight) {
case Type::kF32:
return CallForModel<float, FuncT, TArgs...>( //
model, std::forward<TArgs>(args)...);
case Type::kBF16:
return CallForModel<BF16, FuncT, TArgs...>(model,
std::forward<TArgs>(args)...);
case Type::kSFP:
return CallForModel<SfpStream, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
default:
HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
}
}
#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \
X(CONFIGT, float) \
X(CONFIGT, BF16) \
X(CONFIGT, SfpStream)
#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigPaliGemma_224) \
static_assert(true, "Allow trailing ;")
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
using CP = ConfigPair<ConfigGemmaTiny<TWEIGHT>, ConfigGemmaTiny<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
using CP = ConfigPair<ConfigGemma2B<TWEIGHT>, ConfigGemma2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
using CP = ConfigPair<ConfigGemma7B<TWEIGHT>, ConfigGemma7B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
using CP = ConfigPair<ConfigGriffin2B<TWEIGHT>, ConfigGriffin2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_2B: { \
using CP = ConfigPair<ConfigGemma2_2B<TWEIGHT>, ConfigGemma2_2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_9B: { \
using CP = ConfigPair<ConfigGemma2_9B<TWEIGHT>, ConfigGemma2_9B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_27B: { \
using CP = \
ConfigPair<ConfigGemma2_27B<TWEIGHT>, ConfigGemma2_27B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::PALIGEMMA_224: { \
using CP = ConfigPair<ConfigPaliGemma_224<TWEIGHT>, \
ConfigPaliGemma_224<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}
// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
// normal function argument. MODEL and WEIGHT are enums.
// For gemma.cc, we use overloaded extern functions for faster builds. However,
// this is still used in compress_weights because its compile time is OK.
#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS) \
switch (WEIGHT) { \
case Type::kF32: \
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
break; \
case Type::kBF16: \
GEMMA_DISPATCH_MODEL(MODEL, BF16, FUNC, ARGS); \
break; \
case Type::kSFP: \
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
break; \
default: \
HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT)); \
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// //
// __builtin_sqrt is not constexpr as of Clang 17. float EmbeddingScaling(size_t model_dim);
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo` float ChooseQueryScale(const ModelConfig& config);
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(TConfig::kModelDim))));
}
static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(model_dim))));
}
template <class TConfig>
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
Sqrt(static_cast<float>(TConfig::kModelDim / TConfig::kHeads));
// QueryScaleType::SqrtKeySize
return 1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
}
} // namespace gcpp } // namespace gcpp

246
gemma/configs.cc Normal file
View File

@ -0,0 +1,246 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/configs.h"
#include "hwy/base.h"
namespace gcpp {
static ModelConfig ConfigNoSSM() {
ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
"gr_lin_y_w", "gr_lin_out_w",
"gr_gate_w", "gating_ein", "linear_w"}};
return config;
}
static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }
static ModelConfig ConfigBaseGemmaV2() {
ModelConfig config = ConfigNoSSM();
config.att_cap = 50.0f;
config.final_cap = 30.0f;
return config;
}
static ModelConfig ConfigGemma2_27B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_27B";
config.model = Model::GEMMA2_27B;
config.model_dim = 4608;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
.heads = 32,
.kv_heads = 16,
.qkv_dim = 128,
.post_norm = PostNormType::Scale};
config.layer_configs = {46, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
return config;
}
static ModelConfig ConfigGemma2_9B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_9B";
config.model = Model::GEMMA2_9B;
config.model_dim = 3584;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
.heads = 16,
.kv_heads = 8,
.qkv_dim = 256,
.post_norm = PostNormType::Scale};
config.layer_configs = {42, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
return config;
}
static ModelConfig ConfigGemma2_2B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_2B";
config.model = Model::GEMMA2_2B;
config.model_dim = 2304;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
.heads = 8,
.kv_heads = 4,
.qkv_dim = 256,
.post_norm = PostNormType::Scale};
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<26, 2>({4096, 8192});
return config;
}
static ModelConfig ConfigGemma7B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma7B";
config.model = Model::GEMMA_7B;
config.model_dim = 3072;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
.heads = 16,
.kv_heads = 16,
.qkv_dim = 256,
};
config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
return config;
}
static ModelConfig ConfigGemma2B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma2B";
config.model = Model::GEMMA_2B;
config.model_dim = 2048;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
.heads = 8,
.kv_heads = 1,
.qkv_dim = 256,
};
config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
return config;
}
static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.model_dim = 128;
config.vocab_size = 64;
config.seq_len = 32;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 256,
.heads = 4,
.kv_heads = 1,
.qkv_dim = 16,
};
config.layer_configs = {3, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
// This is required for optimize_test to pass.
config.final_cap = 30.0f;
return config;
}
static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.model_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
config.model_dim = 2560;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 2048;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.griffin_dim = config.model_dim,
.ff_hidden_dim = 7680,
.heads = 10,
.kv_heads = 1,
.qkv_dim = 256,
.conv1d_width = 4,
.ff_biases = true,
.softmax_attn_output_biases = true,
.type = LayerAttentionType::kGriffinRecurrentBlock,
.activation = ActivationType::Gelu,
.post_qk = PostQKType::Rope,
};
config.layer_configs = {26, layer_config};
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.num_tensor_scales = 140;
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
config.use_local_attention = true;
// This is required for optimize_test to pass.
config.final_cap = 0.0f;
return config;
}
static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.vit_seq_len = 16 * 16;
LayerConfig layer_config = {
.model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304,
.heads = 16,
.kv_heads = 16,
.qkv_dim = 72,
.type = LayerAttentionType::kVit,
.patch_width = 14,
.image_size = 224,
};
config.vit_layer_configs = {27, layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
return config;
}
ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ConfigGemma2B();
case Model::GEMMA_7B:
return ConfigGemma7B();
case Model::GEMMA2_2B:
return ConfigGemma2_2B();
case Model::GEMMA2_9B:
return ConfigGemma2_9B();
case Model::GEMMA2_27B:
return ConfigGemma2_27B();
case Model::GRIFFIN_2B:
return ConfigGriffin2B();
case Model::GEMMA_TINY:
return ConfigGemmaTiny();
case Model::PALIGEMMA_224:
return ConfigPaliGemma_224();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp

View File

@ -21,6 +21,9 @@
#include <stddef.h> #include <stddef.h>
#include <array> #include <array>
#include <string>
#include <unordered_set>
#include <vector>
#include "compression/shared.h" // BF16 #include "compression/shared.h" // BF16
@ -57,6 +60,7 @@ enum class PostNormType {
// Post qk projection operation type. // Post qk projection operation type.
enum class PostQKType { enum class PostQKType {
Rope, Rope,
HalfRope,
}; };
// FFW activation function. // FFW activation function.
@ -76,358 +80,115 @@ enum class ResidualType {
}; };
template <size_t kNum> template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig( std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
LayerAttentionType type) { return std::vector<LayerAttentionType>(kNum, type);
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
} }
template <size_t kNum> template <size_t kNum>
constexpr std::array<size_t, kNum> FixedAttentionWindowSizes( std::vector<size_t> FixedAttentionWindowSizes(size_t window_size) {
size_t window_size) { return std::vector<size_t>(kNum, window_size);
std::array<size_t, kNum> window_size_configs = {};
for (size_t& l : window_size_configs) {
l = window_size;
}
return window_size_configs;
} }
// Repeat window_size_pattern for kNum / kPatternSize times. // Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize> template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes( std::vector<size_t> RepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) { const std::array<size_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0, static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize"); "kNum must be a multiple of kPatternSize");
std::array<size_t, kNum> window_size_configs = {}; std::vector<size_t> window_size_configs(kNum);
for (size_t i = 0; i < kNum; ++i) { for (size_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize]; window_size_configs[i] = window_size_pattern[i % kPatternSize];
} }
return window_size_configs; return window_size_configs;
} }
template <size_t kNumLayers> // Model variants: see configs.cc for details.
constexpr size_t NumLayersOfTypeBefore( enum class Model {
const std::array<LayerAttentionType, kNumLayers>& layers, UNKNOWN,
LayerAttentionType type, size_t num) { GEMMA_2B,
GEMMA_7B,
GEMMA2_9B,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
};
struct LayerConfig {
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }
size_t model_dim = 0;
size_t griffin_dim = 0;
size_t ff_hidden_dim = 0;
size_t heads = 0;
size_t kv_heads = 0;
size_t qkv_dim = 0;
size_t conv1d_width = 0;
bool ff_biases = false;
bool softmax_attn_output_biases = false;
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
};
struct ModelConfig {
size_t CachePosSize() const {
size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}
size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const {
size_t count = 0; size_t count = 0;
for (size_t i = 0; i < num; i++) { for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++; if (layer_configs[i].type == type) ++count;
} }
return count; return count;
}
template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
} }
};
template <class TConfig, typename = void> size_t NumLayersOfType(LayerAttentionType type) const {
struct CachePosSize { return NumLayersOfTypeBefore(type, layer_configs.size());
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
} }
size_t NumHeads() const {
size_t num_heads = 0;
for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads);
}
return num_heads;
}
std::string model_name;
Model model;
ModelTraining training;
Type weight;
size_t model_dim = 0;
size_t vit_model_dim = 0;
size_t vocab_size = 0;
size_t seq_len = 0;
size_t vit_seq_len = 0;
size_t num_tensor_scales = 0;
size_t num_vit_scales = 0;
size_t top_k = kTopK;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false;
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<size_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
int norm_num_groups = 1;
int model_family_version = 1;
}; };
struct ConfigNoVit { // Returns the config for the given model.
struct VitConfig { ModelConfig ConfigFromModel(Model model);
// Some of these are needed to make the compiler happy when trying to
// generate code that will actually never be used.
using Weight = float;
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
FixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
struct ConfigNoSSM : ConfigNoVit {
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;
};
struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
struct ConfigBaseGemmaV2 : ConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};
template <typename TWeight>
struct ConfigGemma2_27B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};
template <typename TWeight>
struct ConfigGemma2_9B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
FixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
FixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct ConfigPaliGemma_224 : public ConfigGemma2B<TWeight> {
// On the LM side, the vocab size is one difference to Gemma1-2B in the
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
// Sub-config for the Vision-Transformer part.
struct VitConfig : public ConfigNoSSM {
using Weight = TWeight;
// The ViT parts. https://arxiv.org/abs/2305.13035
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
FixedLayerConfig<27>(LayerAttentionType::kVit);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kModelDim = 1152;
static constexpr int kFFHiddenDim = 4304;
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 72;
static constexpr int kSeqLen = 16 * 16; // 256
static constexpr bool kFFBiases = true;
// The Vit part does not have a vocabulary, the image patches are embedded.
static constexpr int kVocabSize = 0;
// Dimensions related to image processing.
static constexpr int kPatchWidth = 14;
static constexpr int kImageSize = 224;
// Necessary constant for the layer configuration.
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
FixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
FixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};
template <typename TWeight>
struct ConfigGriffin2B : ConfigNoVit {
using Weight = TWeight; // make accessible where we only have a TConfig
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
FixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;
};
} // namespace gcpp } // namespace gcpp

445
gemma/configs_test.cc Normal file
View File

@ -0,0 +1,445 @@
#include "gemma/configs.h"
#include <array>
#include <cstddef>
#include <type_traits>
#include "gtest/gtest.h"
namespace gcpp {
template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> OldFixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}
template <size_t kNum>
constexpr std::array<size_t, kNum> OldFixedAttentionWindowSizes(
size_t window_size) {
std::array<size_t, kNum> window_size_configs = {};
for (size_t& l : window_size_configs) {
l = window_size;
}
return window_size_configs;
}
// Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> OldRepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize");
std::array<size_t, kNum> window_size_configs = {};
for (size_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize];
}
return window_size_configs;
}
template <size_t kNumLayers>
constexpr size_t OldNumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}
template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
}
};
template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
}
};
struct OldConfigNoVit {
struct VitConfig {
// Some of these are needed to make the compiler happy when trying to
// generate code that will actually never be used.
using Weight = float;
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
OldFixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
struct OldConfigNoSSM : OldConfigNoVit {
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;
};
struct OldConfigBaseGemmaV1 : OldConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
struct OldConfigBaseGemmaV2 : OldConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};
template <typename TWeight>
struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
OldFixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};
template <typename TWeight>
struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
OldFixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct OldConfigGemma7B : public OldConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
OldFixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct OldConfigGemma2B : public OldConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
OldFixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};
template <typename TWeight>
struct OldConfigPaliGemma_224 : public OldConfigGemma2B<TWeight> {
// On the LM side, the vocab size is one difference to Gemma1-2B in the
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
// Sub-config for the Vision-Transformer part.
struct VitConfig : public OldConfigNoSSM {
using Weight = TWeight;
// The ViT parts. https://arxiv.org/abs/2305.13035
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
OldFixedLayerConfig<27>(LayerAttentionType::kVit);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kModelDim = 1152;
static constexpr int kFFHiddenDim = 4304;
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 72;
static constexpr int kSeqLen = 16 * 16; // 256
static constexpr bool kFFBiases = true;
// The Vit part does not have a vocabulary, the image patches are embedded.
static constexpr int kVocabSize = 0;
// Dimensions related to image processing.
static constexpr int kPatchWidth = 14;
static constexpr int kImageSize = 224;
// Necessary constant for the layer configuration.
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};
template <typename TWeight>
struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
OldFixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
struct OldConfigGemmaTiny : public OldConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
OldFixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};
template <typename TWeight>
struct OldConfigGriffin2B : OldConfigNoVit {
using Weight = TWeight; // make accessible where we only have a TConfig
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore(
kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore(
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;
};
template <class TConfig>
void AssertMatch(const ModelConfig& config) {
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
if constexpr (TConfig::VitConfig::kModelDim != 0) {
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
config.vit_layer_configs[i].type);
}
}
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
ASSERT_EQ(TConfig::kTopK, config.top_k);
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention);
ASSERT_EQ(TConfig::kQueryScale, config.query_scale);
ASSERT_EQ(TConfig::kGemmaLayers,
config.NumLayersOfType(LayerAttentionType::kGemma));
ASSERT_EQ(TConfig::kGriffinLayers,
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock));
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim);
ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim);
ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads);
ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads);
ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim);
ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width);
ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases);
ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases,
config.layer_configs[i].softmax_attn_output_biases);
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk);
}
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
config.attention_window_sizes.size());
for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) {
ASSERT_EQ(TConfig::kAttentionWindowSizes[i],
config.attention_window_sizes[i]);
}
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
}
TEST(ConfigsTest, OldConfigGemma2B) {
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
}
TEST(ConfigsTest, OldConfigGemma7B) {
AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
}
TEST(ConfigsTest, OldConfigGemma2_2B) {
AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
}
TEST(ConfigsTest, OldConfigGemma2_9B) {
AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
}
TEST(ConfigsTest, OldConfigGemma2_27B) {
AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
}
TEST(ConfigsTest, OldConfigGriffin2B) {
AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
}
TEST(ConfigsTest, OldConfigGemmaTiny) {
AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
}
TEST(ConfigsTest, OldConfigPaliGemma_224) {
AssertMatch<OldConfigPaliGemma_224<float>>(
ConfigFromModel(Model::PALIGEMMA_224));
}
} // namespace gcpp

File diff suppressed because it is too large Load Diff

View File

@ -29,88 +29,90 @@
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "ops/ops-inl.h"
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/threading.h" #include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/profiler.h" // also uses SIMD
namespace gcpp { namespace gcpp {
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, PerClusterPools& pools) const ModelInfo& info, PerClusterPools& pools)
: pools_(pools), tokenizer_(tokenizer_path), info_(info) { : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = model_.Load(weights, info.model, info.weight, pools_.Inner(0));
LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
} }
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
PerClusterPools& pools) PerClusterPools& pools)
: pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) { : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
HWY_ASSERT(info.weight == Type::kF32); HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ = CallForModel<float, AllocateCompressedWeights>(info.model, model_.Allocate(info.model, info.weight, pools_.Inner(0));
pools_.Inner(0));
} }
Gemma::~Gemma() { Gemma::~Gemma() {
} }
// There are >100 instantiations of the inference code. To reduce compile time, // There are >=3 types of the inference code. To reduce compile time,
// we shard them across multiple translation units in instantiations/*.cc. // we shard them across multiple translation units in instantiations/*.cc.
// This declares the functions defined there. We use overloading because // This declares the functions defined there. We use overloading because
// explicit instantiations are still too slow to compile. // explicit instantiations are still too slow to compile.
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ #define GEMMA_DECLARE(TWEIGHT) \
extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, \ const RuntimeConfig& runtime_config, \
const PromptTokens& prompt, size_t pos, \ const PromptTokens& prompt, size_t pos, \
size_t prefix_end, KVCache& kv_cache, \ size_t prefix_end, KVCache& kv_cache, \
PerClusterPools& pools, TimingInfo& timing_info); \ PerClusterPools& pools, TimingInfo& timing_info); \
extern void GenerateBatch( \ extern void GenerateBatch( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, \ const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, \ const KVCaches& kv_caches, PerClusterPools& pools, \
PerClusterPools& pools, TimingInfo& timing_info); \ TimingInfo& timing_info); \
extern void GenerateImageTokens( \ extern void GenerateImageTokens( \
CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \ TWEIGHT, const ModelWeightsStorage& model, \
const RuntimeConfig& runtime_config, const Image& image, \ const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, PerClusterPools& pools); ImageTokens& image_tokens, PerClusterPools& pools);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); GEMMA_DECLARE(float)
GEMMA_DECLARE(BF16)
GEMMA_DECLARE(NuqStream)
GEMMA_DECLARE(SfpStream)
// Adapters to select from the above overloads via CallForModelAndWeight. // Adapters to select from the above overloads via CallForModelWeight.
template <class TConfig> template <class TConfig>
struct GenerateSingleT { struct GenerateSingleT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end, const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, PerClusterPools& pools, KVCache& kv_cache, PerClusterPools& pools,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {
GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos, GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
prefix_end, kv_cache, pools, timing_info); kv_cache, pools, timing_info);
} }
}; };
template <class TConfig> template <class TConfig>
struct GenerateBatchT { struct GenerateBatchT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, PerClusterPools& pools, const KVCaches& kv_caches, PerClusterPools& pools,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {
GenerateBatch(TConfig(), weights_u8, runtime_config, queries_prompt, GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
queries_pos, queries_prefix_end, kv_caches, pools, queries_prefix_end, kv_caches, pools, timing_info);
timing_info);
} }
}; };
template <class TConfig> template <class TConfig>
struct GenerateImageTokensT { struct GenerateImageTokensT {
void operator()(const ByteStorageT& weights_u8, void operator()(const ModelWeightsStorage& model,
const RuntimeConfig& runtime_config, const Image& image, const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, PerClusterPools& pools) const { ImageTokens& image_tokens, PerClusterPools& pools) const {
GenerateImageTokens(TConfig(), weights_u8, runtime_config, image, GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
image_tokens, pools); pools);
} }
}; };
@ -119,9 +121,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
KVCache& kv_cache, TimingInfo& timing_info) { KVCache& kv_cache, TimingInfo& timing_info) {
if (runtime_config.use_spinning) pools_.StartSpinning(); if (runtime_config.use_spinning) pools_.StartSpinning();
CallForModelAndWeight<GenerateSingleT>( model_.CallForModelWeight<GenerateSingleT>(
info_.model, info_.weight, weights_u8_, runtime_config, prompt, pos, runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);
prefix_end, kv_cache, pools_, timing_info);
if (runtime_config.use_spinning) pools_.StopSpinning(); if (runtime_config.use_spinning) pools_.StopSpinning();
} }
@ -142,9 +143,9 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
if (runtime_config.use_spinning) pools_.StartSpinning(); if (runtime_config.use_spinning) pools_.StartSpinning();
CallForModelAndWeight<GenerateBatchT>( model_.CallForModelWeight<GenerateBatchT>(
info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt, runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
queries_pos, mutable_queries_prefix_end, kv_caches, pools_, timing_info); kv_caches, pools_, timing_info);
if (runtime_config.use_spinning) pools_.StopSpinning(); if (runtime_config.use_spinning) pools_.StopSpinning();
} }
@ -153,28 +154,25 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens) { const Image& image, ImageTokens& image_tokens) {
if (runtime_config.use_spinning) pools_.StartSpinning(); if (runtime_config.use_spinning) pools_.StartSpinning();
CallForModelAndWeight<GenerateImageTokensT>(info_.model, info_.weight, model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
weights_u8_, runtime_config, image_tokens, pools_);
image, image_tokens, pools_);
if (runtime_config.use_spinning) pools_.StopSpinning(); if (runtime_config.use_spinning) pools_.StopSpinning();
} }
template <typename TConfig> // Non-template functions moved from gemma-inl.h to avoid ODR violations.
struct GetModelConfig {
ModelConfigInfo operator()() const {
return ModelConfigInfo{
.layers = TConfig::kLayers,
.model_dim = TConfig::kModelDim,
.heads = TConfig::kHeads,
.kv_heads = TConfig::kKVHeads,
.qkv_dim = TConfig::kQKVDim,
};
}
};
ModelConfigInfo Gemma::ModelConfig() const { void RangeChecks(const ModelConfig& weights_config,
return CallForModel<float, GetModelConfig>(info_.model); size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
fprintf(stderr,
"WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -27,6 +27,7 @@
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include "gemma/tokenizer.h" #include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/allocator.h" // RowVectorBatch #include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb #include "util/basics.h" // TokenAndProb
@ -179,15 +180,6 @@ struct TimingInfo {
size_t tokens_generated = 0; size_t tokens_generated = 0;
}; };
// ModelConfigInfo holds model configuration details: number of layers, etc.
struct ModelConfigInfo {
const int layers;
const int model_dim;
const int heads;
const int kv_heads;
const int qkv_dim;
};
class Gemma { class Gemma {
public: public:
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
@ -198,11 +190,11 @@ class Gemma {
PerClusterPools& pools); PerClusterPools& pools);
~Gemma(); ~Gemma();
ModelConfigInfo ModelConfig() const; const ModelConfig& GetModelConfig() const { return model_.Config(); }
const ModelInfo& Info() const { return info_; } const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; } const ModelWeightsStorage& Weights() const { return model_; }
ByteStorageT& MutableWeights() { return weights_u8_; } ModelWeightsStorage& MutableWeights() { return model_; }
// `pos` is the position in the KV cache. Users are responsible for // `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn. // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
@ -241,7 +233,7 @@ class Gemma {
GemmaTokenizer tokenizer_; GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header. // Type-erased so that this can be defined in the header.
ByteStorageT weights_u8_; ModelWeightsStorage model_;
ModelInfo info_; ModelInfo info_;
}; };
@ -251,6 +243,8 @@ class Gemma {
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos, const ModelInfo& info, size_t pos,
std::string& prompt); std::string& prompt);
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, size_t prompt_size);
} // namespace gcpp } // namespace gcpp

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<float>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/2b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/7b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/7b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/9b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/9b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -14,8 +14,7 @@
// limitations under the License. // limitations under the License.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
"gemma/instantiations/2b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<float> #define GEMMA_TYPE hwy::bfloat16_t
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -14,8 +14,7 @@
// limitations under the License. // limitations under the License.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "gemma/instantiations/f32.cc"
"gemma/instantiations/7b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<float> #define GEMMA_TYPE float
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<float>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gr2b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gr2b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<float>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gr2b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -14,8 +14,7 @@
// limitations under the License. // limitations under the License.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "gemma/instantiations/nuq.cc"
"gemma/instantiations/9b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<float> #define GEMMA_TYPE NuqStream
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<float>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/paligemma_224_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -14,8 +14,7 @@
// limitations under the License. // limitations under the License.
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "gemma/instantiations/sfp.cc"
"gemma/instantiations/2b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<SfpStream> #define GEMMA_TYPE SfpStream
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/tiny_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/tiny_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<float>
#include "gemma/gemma-inl.h"

View File

@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/tiny_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<SfpStream>
#include "gemma/gemma-inl.h"

View File

@ -15,32 +15,40 @@
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include <algorithm>
#include "gemma/common.h" // CallForModel #include "gemma/common.h" // CallForModel
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes #include "hwy/base.h" // ZeroBytes
namespace gcpp { namespace gcpp {
namespace {
template <class TConfig> // prefill_tbatch_size is the maximum number of tokens from one query to
struct CreateKVCache { // prefill at a time.
KVCache operator()(size_t prefill_tbatch_size) const { KVCache KVCache::Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache = {}; KVCache kv_cache = {};
const size_t size_cache_pos = CachePosSize<TConfig>()(); const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) { if (size_cache_pos != 0) {
// Allocate more so that prefill can always access one batch, even if // Allocate more so that prefill can always access one batch, even if
// near the end of the sequence. // near the end of the sequence.
kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size; kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
kv_cache.kv_cache = kv_cache.kv_cache =
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos); hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
} }
size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
// TODO(patrickms): Add query batching support for Griffin. // TODO(patrickms): Add query batching support for Griffin.
if (TConfig::kGriffinLayers) { if (num_griffin_layers > 0) {
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; size_t conv1d_width = 0;
for (const auto& layer_config : weights_config.layer_configs) {
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
}
const size_t conv1d_cache_size = const size_t conv1d_cache_size =
TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) * num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
TConfig::kModelDim; weights_config.model_dim;
if (conv1d_cache_size != 0) { if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size); kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
hwy::ZeroBytes(kv_cache.conv1d_cache.get(), hwy::ZeroBytes(kv_cache.conv1d_cache.get(),
@ -48,7 +56,7 @@ struct CreateKVCache {
} }
const size_t rglru_cache_size = const size_t rglru_cache_size =
TConfig::kGriffinLayers * TConfig::kModelDim; num_griffin_layers * weights_config.model_dim;
if (rglru_cache_size != 0) { if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size); kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
hwy::ZeroBytes(kv_cache.rglru_cache.get(), hwy::ZeroBytes(kv_cache.rglru_cache.get(),
@ -57,17 +65,6 @@ struct CreateKVCache {
} // kGriffinLayers } // kGriffinLayers
return kv_cache; return kv_cache;
}
};
} // namespace
// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) {
// TWeight=float is a placeholder and unused because CreateKVCache does not
// use TConfig::Weight.
return CallForModel</*TWeight=*/float, CreateKVCache>(model_type,
prefill_tbatch_size);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -35,7 +35,8 @@ struct KVCache {
// kModelDim * kGriffinLayers // kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> rglru_cache; hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
static KVCache Create(Model type, size_t prefill_tbatch_size); static KVCache Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size);
}; };
} // namespace gcpp } // namespace gcpp

View File

@ -194,7 +194,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
Gemma model = CreateGemma(loader, pools); Gemma model = CreateGemma(loader, pools);
KVCache kv_cache = KVCache kv_cache =
KVCache::Create(model.Info().model, inference.prefill_tbatch_size); KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
if (app.verbosity >= 1) { if (app.verbosity >= 1) {
std::string instructions = std::string instructions =

View File

@ -17,12 +17,14 @@
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <memory>
#include <random>
#include <vector> #include <vector>
#include "compression/compress.h" #include "compression/compress.h"
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "gemma/common.h" #include "gemma/common.h"
#include "util/allocator.h" #include "gemma/configs.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_ABORT #include "hwy/base.h" // HWY_ABORT
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -31,58 +33,128 @@
namespace gcpp { namespace gcpp {
namespace { template <typename T>
template <class TConfig> struct TensorLoader {
struct LoadCompressedWeightsT { void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
ByteStorageT operator()(const Path& weights, hwy::ThreadPool& pool) const { CacheLoader& loader) {
PROFILER_ZONE("Startup.LoadCompressedWeights"); weights.ForEachTensor(
{&weights}, fet,
[&loader](const char* name, hwy::Span<MatPtr*> tensors) {
loader(name, tensors);
});
}
};
BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
if (!weights.Exists()) { if (!weights.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.", HWY_ABORT("The model weights file '%s' does not exist.",
weights.path.c_str()); weights.path.c_str());
} }
// Allocate compressed weights.
using CWeights = CompressedWeights<TConfig>;
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (c_weights) CWeights(pool);
CacheLoader loader(weights); CacheLoader loader(weights);
ForEachType fet = ForEachType fet =
loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc; loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
CWeights::ForEachTensor( if (fet == ForEachType::kLoadWithToc) {
{c_weights}, fet, // TODO(rays): Load the config from the file.
[&loader](const char* name, hwy::Span<MatPtr*> tensors) { HWY_ABORT("TOC not supported yet.");
loader(name, tensors); } else {
}); // No Toc-> no config.
std::vector<float> scales(TConfig::kNumTensorScales); config_ = ConfigFromModel(model_type);
if (TConfig::kNumTensorScales > 0) { config_.weight = weight_type;
}
CreateForType(weight_type, pool);
CallForModelWeightT<TensorLoader>(fet, loader);
std::vector<float> scales(config_.num_tensor_scales + config_.num_vit_scales);
if (!scales.empty()) {
loader.LoadScales(scales.data(), scales.size()); loader.LoadScales(scales.data(), scales.size());
} }
if (!loader.ReadAll(pool, c_weights->model_storage)) { BlobError err = loader.ReadAll(pool, model_storage_);
HWY_ABORT("Failed to load model weights."); if (err != 0) {
fprintf(stderr, "Failed to load model weights: %d\n", err);
return err;
} }
if (TConfig::kNumTensorScales > 0) { if (!scales.empty()) {
c_weights->GetOrApplyScales(scales); GetOrApplyScales(scales);
} }
{ if (fet == ForEachType::kLoadNoToc) {
PROFILER_ZONE("Startup.Reshape"); PROFILER_ZONE("Startup.Reshape");
c_weights->Reshape(pool); AllocAndCopyWithTranspose(pool);
} }
return c_weights_u8; return 0;
} }
};
} // namespace
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
Type weight_type, hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type, PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
weights, pool); config_ = config;
config_.weight = weight_type;
CreateForType(weight_type, pool);
if (float_weights_) float_weights_->Allocate(model_storage_, pool);
if (bf16_weights_) bf16_weights_->Allocate(model_storage_, pool);
if (sfp_weights_) sfp_weights_->Allocate(model_storage_, pool);
if (nuq_weights_) nuq_weights_->Allocate(model_storage_, pool);
}
class WeightInitializer {
public:
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
float* data = tensors[0]->data<float>();
for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
data[i] = dist_(gen_);
}
tensors[0]->set_scale(1.0f);
}
private:
std::normal_distribution<float> dist_;
std::mt19937& gen_;
};
void ModelWeightsStorage::RandInit(std::mt19937& gen) {
HWY_ASSERT(float_weights_);
WeightInitializer init(gen);
ModelWeightsPtrs<float>::ForEachTensor({float_weights_.get()},
ForEachType::kLoadNoToc, init);
}
void ModelWeightsStorage::ZeroInit() {
if (float_weights_) float_weights_->ZeroInit();
if (bf16_weights_) bf16_weights_->ZeroInit();
if (sfp_weights_) sfp_weights_->ZeroInit();
if (nuq_weights_) nuq_weights_->ZeroInit();
}
void ModelWeightsStorage::GetOrApplyScales(std::vector<float>& scales) {
if (float_weights_) float_weights_->GetOrApplyScales(scales);
if (bf16_weights_) bf16_weights_->GetOrApplyScales(scales);
if (sfp_weights_) sfp_weights_->GetOrApplyScales(scales);
if (nuq_weights_) nuq_weights_->GetOrApplyScales(scales);
}
void ModelWeightsStorage::AllocAndCopyWithTranspose(hwy::ThreadPool& pool) {
if (float_weights_)
float_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
if (bf16_weights_)
bf16_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
if (sfp_weights_)
sfp_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
if (nuq_weights_)
nuq_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
}
void ModelWeightsStorage::CopyWithTranspose(hwy::ThreadPool& pool) {
if (float_weights_) float_weights_->CopyWithTranspose(pool);
if (bf16_weights_) bf16_weights_->CopyWithTranspose(pool);
if (sfp_weights_) sfp_weights_->CopyWithTranspose(pool);
if (nuq_weights_) nuq_weights_->CopyWithTranspose(pool);
} }
namespace { namespace {
// For reasons unknown, this is shown as potentially unused in the IDE.
void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) { void LogVec(const char* name, const float* data, size_t len) {
hwy::Stats stats; hwy::Stats stats;
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
stats.Notify(data[i]); stats.Notify(data[i]);
@ -91,36 +163,44 @@ void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) {
name, len, stats.Min(), stats.Mean(), stats.Max()); name, len, stats.Min(), stats.Mean(), stats.Max());
} }
class WeightLogger { } // namespace
public:
void operator()(const char* name, hwy::Span<MatPtr*> tensors) { void ModelWeightsStorage::LogWeightStats() {
size_t total_weights = 0;
// Only for float weights.
ModelWeightsPtrs<float>::ForEachTensor(
{float_weights_.get()}, ForEachType::kInitNoToc,
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
const MatPtr& tensor = *tensors[0]; const MatPtr& tensor = *tensors[0];
if (tensor.scale() != 1.0f) { if (tensor.scale() != 1.0f) {
printf("[scale=%f] ", tensor.scale()); printf("[scale=%f] ", tensor.scale());
} }
LogVec(name, tensor.data<float>(), tensor.NumElements()); LogVec(name, tensor.data<float>(), tensor.NumElements());
total_weights += tensor.NumElements(); total_weights += tensor.NumElements();
} });
size_t total_weights = 0; printf("%-20s %12zu\n", "Total", total_weights);
}; }
template <typename TConfig> void ModelWeightsStorage::CreateForType(Type weight_type,
struct LogWeightStatsT { hwy::ThreadPool& pool) {
void operator()(const ByteStorageT& weights_u8) const { switch (weight_type) {
auto& weights = case Type::kF32:
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get()); float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
WeightLogger logger; break;
CompressedWeights<TConfig>::ForEachTensor( case Type::kBF16:
{&weights}, ForEachType::kIgnoreNulls, logger); bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
printf("%-20s %12zu\n", "Total", logger.total_weights); break;
case Type::kSFP:
sfp_weights_ =
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
break;
case Type::kNUQ:
nuq_weights_ =
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
} }
};
} // namespace
void LogWeightStats(gcpp::Model model_type, Type weight_type,
const ByteStorageT& weights) {
HWY_ASSERT(weight_type == Type::kF32);
CallForModel<float, LogWeightStatsT>(model_type, weights);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -18,9 +18,10 @@
#include <stddef.h> #include <stddef.h>
#include <array>
#include <complex> #include <complex>
#include <cstdio> #include <cstdio>
#include <memory>
#include <random>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -29,7 +30,6 @@
#include "compression/shared.h" #include "compression/shared.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
@ -53,57 +53,79 @@ enum class ForEachType {
kInitNoToc, kInitNoToc,
}; };
template <class TConfig> template <class Weight>
struct CompressedLayer { struct LayerWeightsPtrs {
// Large data is constructed separately. // Large data is constructed separately.
CompressedLayer() explicit LayerWeightsPtrs(const LayerConfig& config)
: attn_vec_einsum_w("att_ein", kModelDim, kHeads * kQKVDim), : attn_vec_einsum_w("att_ein", config.model_dim,
qkv_einsum_w("qkv_ein", (kHeads + 2 * kKVHeads) * kQKVDim, kModelDim), config.heads * config.qkv_dim),
qkv_einsum_w1("qkv1_w", kHeads * kQKVDim, kModelDim), qkv_einsum_w("qkv_ein",
qkv_einsum_w2("qkv2_w", 2 * kKVHeads * kQKVDim, kModelDim), (config.heads + 2 * config.kv_heads) * config.qkv_dim,
attention_output_biases("attn_ob", 1, kAOBiasDim), config.model_dim),
griffin({.linear_x_w = {"gr_lin_x_w", kGriffinDim, kGriffinDim}, qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim,
.linear_x_biases = {"gr_lin_x_b", 1, kGriffinDim}, config.model_dim),
.linear_y_w = {"gr_lin_y_w", kGriffinDim, kGriffinDim}, qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim,
.linear_y_biases = {"gr_lin_y_b", 1, kGriffinDim}, config.model_dim),
.linear_out_w = {"gr_lin_out_w", kGriffinDim, kGriffinDim}, attention_output_biases(
.linear_out_biases = {"gr_lin_out_b", 1, kGriffinDim}, "attn_ob", 1,
.conv_w = {"gr_conv_w", kConv1dWidth, kGriffinDim}, config.softmax_attn_output_biases ? config.model_dim : 0),
.conv_biases = {"gr_conv_b", 1, kGriffinDim}, griffin(
.gate_w = {"gr_gate_w", 2 * kGriffinDim, kGriffinDim / kHeads}, {.linear_x_w = {"gr_lin_x_w", config.griffin_dim,
.gate_biases = {"gr_gate_b", 1, kGriffinDim * 2}, config.griffin_dim},
.a = {"gr_a", 1, kGriffinDim}}), .linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
.linear_y_w = {"gr_lin_y_w", config.griffin_dim,
config.griffin_dim},
.linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
.linear_out_w = {"gr_lin_out_w", config.griffin_dim,
config.griffin_dim},
.linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
.conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
.conv_biases = {"gr_conv_b", 1, config.griffin_dim},
.gate_w = {"gr_gate_w", 2 * config.griffin_dim,
config.griffin_dim / config.heads},
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
.a = {"gr_a", 1, config.griffin_dim}}),
// MultiHeadDotProductAttention. // MultiHeadDotProductAttention.
vit({.attn_out_w = {"attn_out_w", kHeads * kQKVDim, kModelDim}, vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim,
.attn_out_b = {"attn_out_b", 1, kModelDim}, config.model_dim},
.qkv_einsum_w = {"qkv_ein_w", (kHeads + 2 * kKVHeads) * kQKVDim, .attn_out_b = {"attn_out_b", 1, config.model_dim},
kModelDim}, .qkv_einsum_w = {"qkv_ein_w",
.qkv_einsum_b = {"qkv_ein_b", (kHeads + 2 * kKVHeads), kQKVDim}, (config.heads + 2 * config.kv_heads) *
.linear_0_w = {"linear_0_w", kModelDim, kFFHiddenDim}, config.qkv_dim,
.linear_0_b = {"linear_0_b", 1, kFFHiddenDim}, config.model_dim},
.linear_1_w = {"linear_1_w", kFFHiddenDim, kModelDim}, .qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads),
.linear_1_b = {"linear_1_b", 1, kModelDim}, config.qkv_dim},
.layer_norm_0_bias = {"ln_0_bias", 1, kModelDim}, .linear_0_w = {"linear_0_w", config.model_dim,
.layer_norm_0_scale = {"ln_0_scale", 1, kModelDim}, config.ff_hidden_dim},
.layer_norm_1_bias = {"ln_1_bias", 1, kModelDim}, .linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
.layer_norm_1_scale = {"ln_1_scale", 1, kModelDim}}), .linear_1_w = {"linear_1_w", config.ff_hidden_dim,
gating_einsum_w("gating_ein", 2 * kFFHiddenDim, kModelDim), config.model_dim},
gating_einsum_w1("gating1_w", kFFHiddenDim, kModelDim), .linear_1_b = {"linear_1_b", 1, config.model_dim},
gating_einsum_w2("gating2_w", kFFHiddenDim, kModelDim), .layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
linear_w("linear_w", kModelDim, kFFHiddenDim), .layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
pre_attention_norm_scale("pre_att_ns", 1, kModelDim), .layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
pre_ffw_norm_scale("pre_ff_ns", 1, kModelDim), .layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim,
config.model_dim),
gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
post_attention_norm_scale( post_attention_norm_scale(
"post_att_ns", 1, kPostNorm == PostNormType::Scale ? kModelDim : 0), "post_att_ns", 1,
post_ffw_norm_scale("post_ff_ns", 1, config.post_norm == PostNormType::Scale ? config.model_dim : 0),
kPostNorm == PostNormType::Scale ? kModelDim : 0), post_ffw_norm_scale(
ffw_gating_biases("ffw_gat_b", 1, kFFBiases ? 2 * kFFHiddenDim : 0), "post_ff_ns", 1,
ffw_output_biases("ffw_out_b", 1, kFFBiases ? kModelDim : 0), config.post_norm == PostNormType::Scale ? config.model_dim : 0),
att_weights("att_w", kModelDim, kHeads * kQKVDim) ffw_gating_biases("ffw_gat_b", 1,
{} config.ff_biases ? 2 * config.ff_hidden_dim : 0),
~CompressedLayer() = default; ffw_output_biases("ffw_out_b", 1,
config.ff_biases ? config.model_dim : 0),
att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
layer_config(config) {}
~LayerWeightsPtrs() = default;
using Weight = typename TConfig::Weight;
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that // If weights are f32, also f32; otherwise at least bf16. Useful for ops that
// do not yet support smaller compressed types, or require at least bf16. When // do not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32. // weights are f32, we also want such tensors to be f32.
@ -113,25 +135,6 @@ struct CompressedLayer {
hwy::If<hwy::IsSame<Weight, double>(), double, hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>; hwy::If<IsF32<Weight>(), float, BF16>>>;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize =
(kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumBSize = (kHeads + 2 * kKVHeads) * kQKVDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim =
TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim =
TConfig::kGriffinLayers > 0 ? kModelDim : 0;
template <class T> template <class T>
using ArrayT = MatPtrT<T>; using ArrayT = MatPtrT<T>;
@ -195,28 +198,32 @@ struct CompressedLayer {
// Reshaped attention; not loaded from disk via ForEachTensor. // Reshaped attention; not loaded from disk via ForEachTensor.
ArrayT<Weight> att_weights; ArrayT<Weight> att_weights;
const LayerConfig& layer_config;
// Initializes att_weights from attn_vec_einsum_w, hence this must be called // Initializes att_weights from attn_vec_einsum_w, hence this must be called
// after loading weights via ForEachTensor. // after loading weights via ForEachTensor.
// TODO: update compression/convert_weights to bake this in. // TODO: update compression/convert_weights to bake this in.
void Reshape(MatStorage& storage) { void Reshape(MatStorage* storage) {
if (attn_vec_einsum_w.data() == nullptr) return; if (attn_vec_einsum_w.data() == nullptr) return;
constexpr size_t kModelDim = TConfig::kModelDim; const size_t model_dim = layer_config.model_dim;
constexpr size_t kHeads = TConfig::kHeads; const size_t heads = layer_config.heads;
constexpr size_t kQKVDim = TConfig::kQKVDim; const size_t qkv_dim = layer_config.qkv_dim;
// Would have to implement a CompressTraits::Copy for NUQ. // TODO: implement a CompressTraits::Copy for NUQ.
static_assert(!hwy::IsSame<Weight, NuqStream>()); // static_assert(!hwy::IsSame<Weight, NuqStream>());
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
storage.Allocate(); if (storage != nullptr) {
att_weights.SetPtr(storage); storage->Allocate();
for (size_t m = 0; m < kModelDim; ++m) { att_weights.SetPtr(*storage);
Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim; }
for (size_t h = 0; h < kHeads; ++h) { for (size_t m = 0; m < model_dim; ++m) {
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes( hwy::CopyBytes(
attn_vec_einsum_w.data() + h * kModelDim * kQKVDim + m * kQKVDim, attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * kQKVDim, kQKVDim * sizeof(Weight)); out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
} }
} }
att_weights.set_scale(attn_vec_einsum_w.scale()); att_weights.set_scale(attn_vec_einsum_w.scale());
@ -235,11 +242,11 @@ struct CompressedLayer {
} }
template <class Func> template <class Func>
static void ForEachTensor(const std::vector<CompressedLayer<TConfig>*>& ptrs, static void ForEachTensor(const std::vector<LayerWeightsPtrs<Weight>*>& ptrs,
int layer_idx, ForEachType fet, Func func, int layer_idx, ForEachType fet, Func func,
char sep = ' ', int sep_index = -1) { char sep = ' ', int sep_index = -1) {
MatPtr* tensors[ptrs.size()]; MatPtr* tensors[ptrs.size()];
auto type = TConfig::kLayerConfig[layer_idx]; auto type = ptrs[0]->layer_config.type;
if (type == LayerAttentionType::kVit) { if (type == LayerAttentionType::kVit) {
// MHA. // MHA.
GEMMA_CALL_FUNC(vit.attn_out_w); GEMMA_CALL_FUNC(vit.attn_out_w);
@ -296,17 +303,17 @@ struct CompressedLayer {
GEMMA_CALL_FUNC(pre_attention_norm_scale); GEMMA_CALL_FUNC(pre_attention_norm_scale);
GEMMA_CALL_FUNC(pre_ffw_norm_scale); GEMMA_CALL_FUNC(pre_ffw_norm_scale);
if (TConfig::kPostNorm == PostNormType::Scale) { if (ptrs[0]->layer_config.post_norm == PostNormType::Scale) {
GEMMA_CALL_FUNC(post_attention_norm_scale); GEMMA_CALL_FUNC(post_attention_norm_scale);
GEMMA_CALL_FUNC(post_ffw_norm_scale); GEMMA_CALL_FUNC(post_ffw_norm_scale);
} }
if (TConfig::kFFBiases) { if (ptrs[0]->layer_config.ff_biases) {
GEMMA_CALL_FUNC(ffw_gating_biases); GEMMA_CALL_FUNC(ffw_gating_biases);
GEMMA_CALL_FUNC(ffw_output_biases); GEMMA_CALL_FUNC(ffw_output_biases);
} }
if (TConfig::kSoftmaxAttnOutputBiases && if (ptrs[0]->layer_config.softmax_attn_output_biases &&
type == LayerAttentionType::kGemma) { type == LayerAttentionType::kGemma) {
GEMMA_CALL_FUNC(attention_output_biases); GEMMA_CALL_FUNC(attention_output_biases);
} }
@ -322,47 +329,45 @@ struct CompressedLayer {
// Allocates memory for all the tensors in the layer. // Allocates memory for all the tensors in the layer.
// Note that this is slow and only used for a stand-alone layer. // Note that this is slow and only used for a stand-alone layer.
void Allocate() { void Allocate(std::vector<MatStorage>& layer_storage) {
layer_storage.clear(); ForEachTensor(
ForEachTensor({this}, /*layer_idx=*/0, ForEachType::kInitNoToc, {this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
[this](const char* name, hwy::Span<MatPtr*> tensors) { [&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
this->layer_storage.emplace_back(*tensors[0]); layer_storage.emplace_back(*tensors[0]);
layer_storage.back().Allocate(); layer_storage.back().Allocate();
tensors[0]->SetPtr(layer_storage.back()); tensors[0]->SetPtr(layer_storage.back());
}); });
} }
// Storage for all the matrices and vectors. Only used for a stand-alone
// layer. For a model, the CompressedWeights::model_storage is used instead.
std::vector<MatStorage> layer_storage;
}; };
template <class TConfig> template <class Weight>
struct CompressedWeights { struct ModelWeightsPtrs {
explicit CompressedWeights(hwy::ThreadPool& pool) ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", TConfig::kVocabSize, : embedder_input_embedding("c_embedding", config.vocab_size,
TConfig::kModelDim), config.model_dim),
final_norm_scale("c_final_norm", 1, TConfig::kModelDim), final_norm_scale("c_final_norm", 1, config.model_dim),
vit_encoder_norm_bias("enc_norm_bias", 1, vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
TConfig::VitConfig::kModelDim), vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
TConfig::VitConfig::kModelDim),
vit_img_embedding_bias("img_emb_bias", 1,
TConfig::VitConfig::kModelDim),
vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3, vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3,
TConfig::VitConfig::kModelDim), config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", 256, vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
TConfig::VitConfig::kModelDim), vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_bias("img_head_bias", 1, TConfig::kModelDim), vit_img_head_kernel("img_head_kernel", config.vit_model_dim,
vit_img_head_kernel("img_head_kernel", TConfig::VitConfig::kModelDim, config.model_dim),
TConfig::kModelDim), scale_names(config.scale_names),
scale_names({"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", weights_config(config) {
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}) {} c_layers.reserve(config.layer_configs.size());
for (const auto& layer_config : config.layer_configs) {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
}
for (const auto& layer_config : config.vit_layer_configs) {
vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
}
}
~CompressedWeights() = default; ~ModelWeightsPtrs() = default;
using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;
using Weight = typename TConfig::Weight;
using WeightF32OrBF16 = typename CompressedLayer<TConfig>::WeightF32OrBF16;
using WeightF32OrInputT = hwy::If<hwy::IsSame<WeightF32OrBF16, BF16>(), using WeightF32OrInputT = hwy::If<hwy::IsSame<WeightF32OrBF16, BF16>(),
EmbedderInputT, WeightF32OrBF16>; EmbedderInputT, WeightF32OrBF16>;
@ -380,49 +385,73 @@ struct CompressedWeights {
MatPtrT<float> vit_img_head_bias; MatPtrT<float> vit_img_head_bias;
MatPtrT<WeightF32OrBF16> vit_img_head_kernel; MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
// Storage for all the matrices and vectors.
std::vector<MatStorage> model_storage;
std::unordered_set<std::string> scale_names; std::unordered_set<std::string> scale_names;
CompressedLayer<TConfig> c_layers[TConfig::kLayers]; const ModelConfig& weights_config;
CompressedLayer<typename TConfig::VitConfig>
vit_layers[TConfig::VitConfig::kLayers];
// Called by weights.cc after ForEachTensor. std::vector<LayerWeightsPtrs<Weight>> c_layers;
void Reshape(hwy::ThreadPool& pool) { std::vector<LayerWeightsPtrs<Weight>> vit_layers;
// Called by weights.cc after Loading, before att_w has been allocated.
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
std::vector<MatStorage>& model_storage) {
size_t storage_index = model_storage.size(); size_t storage_index = model_storage.size();
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (auto& layer : c_layers) {
model_storage.emplace_back(GetLayer(layer)->att_weights); model_storage.emplace_back(layer.att_weights);
} }
pool.Run(0, TConfig::kLayers, pool.Run(0, c_layers.size(),
[this, storage_index](uint64_t layer, size_t /*thread*/) { [this, &model_storage, storage_index](uint64_t layer,
GetLayer(layer)->Reshape(model_storage[storage_index + layer]); size_t /*thread*/) {
GetLayer(layer)->Reshape(&model_storage[storage_index + layer]);
});
}
// For when the storage has already been allocated.
void CopyWithTranspose(hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape(nullptr);
}); });
} }
void ZeroInit() { void ZeroInit() {
embedder_input_embedding.ZeroInit(); embedder_input_embedding.ZeroInit();
final_norm_scale.ZeroInit(); final_norm_scale.ZeroInit();
for (int i = 0; i < TConfig::kLayers; ++i) { for (size_t i = 0; i < c_layers.size(); ++i) {
c_layers[i].ZeroInit(i); c_layers[i].ZeroInit(i);
} }
} }
const CompressedLayer<TConfig>* GetLayer(size_t layer) const { const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
return &c_layers[layer]; return &c_layers[layer];
} }
CompressedLayer<TConfig>* GetLayer(size_t layer) { return &c_layers[layer]; } LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
const CompressedLayer<typename TConfig::VitConfig>* GetVitLayer( const LayerWeightsPtrs<Weight>* GetVitLayer(size_t layer) const {
size_t layer) const {
return &vit_layers[layer]; return &vit_layers[layer];
} }
CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(size_t layer) { LayerWeightsPtrs<Weight>* GetVitLayer(size_t layer) {
return &vit_layers[layer]; return &vit_layers[layer];
} }
void Allocate(std::vector<MatStorage>& model_storage, hwy::ThreadPool& pool) {
std::vector<MatPtr*> model_toc;
ForEachTensor(
{this}, ForEachType::kInitNoToc,
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
model_toc.push_back(tensors[0]);
model_storage.emplace_back(*tensors[0]);
});
// Allocate in parallel using the pool.
pool.Run(0, model_toc.size(),
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
// model_storage may have had content before we started.
size_t idx = task + model_storage.size() - model_toc.size();
model_storage[idx].Allocate();
model_toc[task]->SetPtr(model_storage[idx]);
});
}
// Copies the data from other to *this. // Copies the data from other to *this.
void CopyFrom(const CompressedWeights<TConfig>& other) { void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor({this, const_cast<CompressedWeights<TConfig>*>(&other)}, ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
ForEachType::kIgnoreNulls, ForEachType::kIgnoreNulls,
[](const char*, hwy::Span<MatPtr*> tensors) { [](const char*, hwy::Span<MatPtr*> tensors) {
hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(), hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
@ -448,16 +477,14 @@ struct CompressedWeights {
++scale_pos; ++scale_pos;
} }
}); });
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales); HWY_ASSERT(scale_pos == weights_config.num_tensor_scales);
} }
template <class Func> template <class Func>
static void ForEachTensor( static void ForEachTensor(const std::vector<ModelWeightsPtrs<Weight>*>& ptrs,
const std::vector<CompressedWeights<TConfig>*>& ptrs, ForEachType fet, ForEachType fet, Func func) {
Func func) { std::vector<LayerWeightsPtrs<Weight>*> layers(ptrs.size());
std::vector<CompressedLayer<TConfig>*> layers(ptrs.size()); std::vector<LayerWeightsPtrs<Weight>*> vit_layers(ptrs.size());
std::vector<CompressedLayer<typename TConfig::VitConfig>*> vit_layers(
ptrs.size());
MatPtr* tensors[ptrs.size()]; MatPtr* tensors[ptrs.size()];
// Variables used by GEMMA_CALL_FUNC. // Variables used by GEMMA_CALL_FUNC.
int layer_idx = -1; int layer_idx = -1;
@ -465,7 +492,7 @@ struct CompressedWeights {
int sep_index = -1; int sep_index = -1;
GEMMA_CALL_FUNC(embedder_input_embedding); GEMMA_CALL_FUNC(embedder_input_embedding);
GEMMA_CALL_FUNC(final_norm_scale); GEMMA_CALL_FUNC(final_norm_scale);
if constexpr (TConfig::VitConfig::kLayers > 0) { if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
// Vit parts. // Vit parts.
GEMMA_CALL_FUNC(vit_encoder_norm_bias); GEMMA_CALL_FUNC(vit_encoder_norm_bias);
GEMMA_CALL_FUNC(vit_encoder_norm_scale); GEMMA_CALL_FUNC(vit_encoder_norm_scale);
@ -476,90 +503,108 @@ struct CompressedWeights {
GEMMA_CALL_FUNC(vit_img_head_kernel); GEMMA_CALL_FUNC(vit_img_head_kernel);
} }
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
for (int i = 0; i < ptrs.size(); ++i) { for (int i = 0; i < ptrs.size(); ++i) {
layers[i] = ptrs[i]->GetLayer(layer_idx); layers[i] = ptrs[i]->GetLayer(layer_idx);
} }
CompressedLayer<TConfig>::ForEachTensor(layers, layer_idx, fet, func); LayerWeightsPtrs<Weight>::ForEachTensor(layers, layer_idx, fet, func);
} }
// Vit layers. Not supported for compress_weights. // Vit layers. Not supported for compress_weights.
if constexpr (TConfig::VitConfig::kLayers > 0) { if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers; for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size();
++layer_idx) { ++layer_idx) {
auto type = TConfig::VitConfig::kLayerConfig[layer_idx]; auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;
HWY_ASSERT(type == LayerAttentionType::kVit); HWY_ASSERT(type == LayerAttentionType::kVit);
for (int i = 0; i < ptrs.size(); ++i) { for (int i = 0; i < ptrs.size(); ++i) {
vit_layers[i] = ptrs[i]->GetVitLayer(layer_idx); vit_layers[i] = ptrs[i]->GetVitLayer(layer_idx);
} }
CompressedLayer<typename TConfig::VitConfig>::ForEachTensor( LayerWeightsPtrs<Weight>::ForEachTensor(vit_layers, layer_idx, fet,
vit_layers, layer_idx, fet, func); func);
} }
} }
} }
}; };
#undef GEMMA_CALL_FUNC #undef GEMMA_CALL_FUNC
// Pair of configs for the compressed and uncompressed weights.
template <class CConfig, class UCConfig>
struct ConfigPair {
using uc = UCConfig;
using c = CConfig;
};
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Interface // Interface
template <typename TConfig> class ModelWeightsStorage {
struct AllocateCompressedWeights { public:
ByteStorageT operator()(hwy::ThreadPool& pool) const { ModelWeightsStorage() = default;
using TWeights = CompressedWeights<TConfig>; ~ModelWeightsStorage() = default;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get()); BlobError Load(const Path& weights, Model model_type, Type weight_type,
new (weights) TWeights(pool); hwy::ThreadPool& pool);
std::vector<MatPtr*> model_toc; void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
auto& model_storage = weights->model_storage; Allocate(ConfigFromModel(model_type), weight_type, pool);
TWeights::ForEachTensor(
{weights}, ForEachType::kInitNoToc,
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
model_toc.push_back(tensors[0]);
model_storage.emplace_back(*tensors[0]);
});
// Allocate in parallel using the pool.
pool.Run(0, model_storage.size(),
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
model_storage[task].Allocate();
model_toc[task]->SetPtr(model_storage[task]);
});
return weights_u8;
} }
}; void Allocate(const ModelConfig& config, Type weight_type,
hwy::ThreadPool& pool);
void RandInit(std::mt19937& gen);
void ZeroInit();
void GetOrApplyScales(std::vector<float>& scales);
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool);
void CopyWithTranspose(hwy::ThreadPool& pool);
void LogWeightStats();
const ModelConfig& Config() const { return config_; }
template <typename TConfig> template <typename T>
struct ZeroInitCompressedWeights { ModelWeightsPtrs<T>* GetWeightsOfType() const {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const { if constexpr (IsSfpStream<T>()) {
CompressedWeights<TConfig>& weights = return sfp_weights_.get();
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get()); } else if constexpr (IsF32<T>()) {
weights.ZeroInit(); return float_weights_.get();
} else if constexpr (IsBF16<T>()) {
return bf16_weights_.get();
} else if constexpr (IsNuqStream<T>()) {
return nuq_weights_.get();
} else {
return HWY_ABORT("Unsupported type.");
} }
};
template <typename TConfig>
struct ReshapeCompressedWeights {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights.Reshape(pool);
} }
template <template <typename T> class FuncT, typename... TArgs>
decltype(auto) CallForModelWeightT(TArgs&&... args) {
if (HWY_LIKELY(sfp_weights_))
return FuncT<SfpStream>()(*sfp_weights_, std::forward<TArgs>(args)...);
if (bf16_weights_)
return FuncT<BF16>()(*bf16_weights_, std::forward<TArgs>(args)...);
if (nuq_weights_)
return FuncT<NuqStream>()(*nuq_weights_, std::forward<TArgs>(args)...);
if (float_weights_)
return FuncT<float>()(*float_weights_, std::forward<TArgs>(args)...);
return HWY_ABORT("No weights loaded.");
}
template <template <typename T> class FuncT, typename... TArgs>
decltype(auto) CallForModelWeight(TArgs&&... args) {
if (HWY_LIKELY(sfp_weights_))
return FuncT<SfpStream>()(*this, std::forward<TArgs>(args)...);
if (bf16_weights_)
return FuncT<BF16>()(*this, std::forward<TArgs>(args)...);
if (nuq_weights_)
return FuncT<NuqStream>()(*this, std::forward<TArgs>(args)...);
if (float_weights_)
return FuncT<float>()(*this, std::forward<TArgs>(args)...);
return HWY_ABORT("No weights loaded.");
}
private:
void CreateForType(Type weight_type, hwy::ThreadPool& pool);
ModelConfig config_;
// To eliminate type templates, we hold a pointer to one of each weight type
// and dispatch to whichever is non-null.
std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
// Storage for all the matrices and vectors.
std::vector<MatStorage> model_storage_;
}; };
// TODO: also add RandInitCompressedWeights
ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

View File

@ -115,7 +115,7 @@ void TestMatVecAdd() {
FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add); FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add);
FloatPtr actual_out = hwy::AllocateAligned<float>(kOuter); FloatPtr actual_out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && expected_out && actual_out); HWY_ASSERT(vec && add && expected_out && actual_out);
MatVecAdd<kOuter, kInner>(*mat, 0, vec.get(), add.get(), actual_out.get(), MatVecAdd(*mat, 0, kOuter, kInner, vec.get(), add.get(), actual_out.get(),
pool); pool);
AssertClose<kOuter>(actual_out, expected_out); AssertClose<kOuter>(actual_out, expected_out);
} }
@ -135,9 +135,8 @@ void TestTwoMatVecAdd() {
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter); FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1); expected_out1 && actual_out1);
TwoMatVecAdd<kOuter, kInner>(*mat0, *mat1, 0, vec.get(), add0.get(), TwoMatVecAdd(*mat0, *mat1, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get(), add1.get(), actual_out0.get(), actual_out1.get(), pool);
pool);
AssertClose<kOuter>(actual_out0, expected_out0); AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1); AssertClose<kOuter>(actual_out1, expected_out1);
} }
@ -156,9 +155,8 @@ void TestTwoOfsMatVecAddLoop() {
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter); FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 && HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1); expected_out1 && actual_out1);
TwoOfsMatVecAddLoop<kOuter, kInner>(*mat, 0, 0, vec.get(), add0.get(), TwoOfsMatVecAddLoop(*mat, 0, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), add1.get(), actual_out0.get(), actual_out1.get());
actual_out1.get());
AssertClose<kOuter>(actual_out0, expected_out0); AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1); AssertClose<kOuter>(actual_out1, expected_out1);
} }

View File

@ -47,10 +47,10 @@ namespace hn = hwy::HWY_NAMESPACE;
// Simple version without tiling nor threading, but two offsets/outputs and // Simple version without tiling nor threading, but two offsets/outputs and
// always with addition. // always with addition.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT, template <typename ArrayT, typename VecT, typename AddT>
typename AddT>
HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0, HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
const size_t mat_ofs1, const size_t mat_ofs1, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0, const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1, const AddT* HWY_RESTRICT add1,
@ -58,13 +58,13 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
float* HWY_RESTRICT out1) { float* HWY_RESTRICT out1) {
PROFILER_ZONE("TwoOfsMatVecAddLoop"); PROFILER_ZONE("TwoOfsMatVecAddLoop");
for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) { for (size_t idx_row = 0; idx_row < outer; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner; const size_t row_ofs0 = mat_ofs0 + (idx_row)*inner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner; const size_t row_ofs1 = mat_ofs1 + (idx_row)*inner;
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) + out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
Dot(mat, row_ofs0, vec_aligned, kInner); Dot(mat, row_ofs0, vec_aligned, inner);
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) + out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
Dot(mat, row_ofs1, vec_aligned, kInner); Dot(mat, row_ofs1, vec_aligned, inner);
} }
} }
@ -84,6 +84,14 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
return kRowsPerStrip; return kRowsPerStrip;
} }
HWY_INLINE size_t RowsPerStrip(const size_t outer) {
// Aim for 128 work items to reduce pool overhead. Must be at least one
// vector; prefer a power of two for faster division.
constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
return outer < 128 ? kLanes
: HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(outer / 128));
}
namespace detail { namespace detail {
// For each i = [0, num_rows), compute partial (length `num_cols`) dot product // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
@ -161,63 +169,63 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
// Stores dot products of rows with `vec_aligned` + add the values from `add` // Stores dot products of rows with `vec_aligned` + add the values from `add`
// (if kAdd), then stores them to `out`. // (if kAdd), then stores them to `out`.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT, template <bool kAdd, typename ArrayT, typename VecT, typename AddT>
typename VecT, typename AddT>
HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add, const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd"); PROFILER_ZONE("MatVecAdd");
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>(); const size_t rows_per_strip = RowsPerStrip(outer);
constexpr size_t kNumStrips = kOuter / kRowsPerStrip; const size_t num_strips = outer / rows_per_strip;
// For each entire strip. // For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda"); PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip; const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0,
kRowsPerStrip, vec_aligned, add, rows_per_strip, vec_aligned, add,
out + r0); out + r0);
}); });
// Remaining rows // Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip; const size_t r0 = num_strips * rows_per_strip;
if (r0 < kOuter) { if (r0 < outer) {
PROFILER_ZONE("MatVec remainder"); PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0; const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0, num_rows,
num_rows, vec_aligned, add, out + r0); vec_aligned, add, out + r0);
} }
} }
// With addition // With addition
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT, template <typename ArrayT, typename VecT, typename AddT>
typename AddT>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add, const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
return MatVecT</*kAdd=*/true, kOuter, kInner>(mat, mat_ofs, vec_aligned, add, return MatVecT</*kAdd=*/true>(mat, mat_ofs, outer, inner, vec_aligned, add,
out, pool); out, pool);
} }
// Without addition // Without addition
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT> template <typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned, const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) { float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecT</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned, MatVecT</*kAdd=*/false>(mat, mat_ofs, outer, inner, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr), /*add=*/static_cast<VecT*>(nullptr), out, pool);
out, pool);
} }
// Two matrices, same vector // Two matrices, same vector
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT1, template <bool kAdd, typename ArrayT1, typename ArrayT2, typename VecT,
typename ArrayT2, typename VecT, typename AddT> typename AddT>
HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1, HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs, const size_t mat_ofs, size_t outer, size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0, const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1, const AddT* HWY_RESTRICT add1,
@ -226,56 +234,56 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
PROFILER_ZONE("TwoMatVecAdd"); PROFILER_ZONE("TwoMatVecAdd");
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>(); const size_t rows_per_strip = RowsPerStrip(outer);
constexpr size_t kNumStrips = kOuter / kRowsPerStrip; const size_t num_strips = outer / rows_per_strip;
// For each entire strip. // For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda"); PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip; const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, inner, r0,
kRowsPerStrip, vec_aligned, add0, rows_per_strip, vec_aligned, add0,
out0 + r0); out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0, detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, inner, r0,
kRowsPerStrip, vec_aligned, add1, rows_per_strip, vec_aligned, add1,
out1 + r0); out1 + r0);
}); });
// Remaining rows // Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip; const size_t r0 = num_strips * rows_per_strip;
if (r0 < kOuter) { if (r0 < outer) {
PROFILER_ZONE("TwoMatVec remainder"); PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0; const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>( detail::FullDotProductsForStrip<kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0); df, mat0, mat_ofs, inner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>( detail::FullDotProductsForStrip<kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0); df, mat1, mat_ofs, inner, r0, num_rows, vec_aligned, add1, out1 + r0);
} }
} }
// With addition // With addition
template <size_t kOuter, size_t kInner, typename ArrayT1, typename ArrayT2, template <typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
typename VecT, typename AddT>
HWY_NOINLINE void TwoMatVecAdd( HWY_NOINLINE void TwoMatVecAdd(
const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs, const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0, const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0, const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1, hwy::ThreadPool& pool) { float* HWY_RESTRICT out1, hwy::ThreadPool& pool) {
return TwoMatVecT</*kAdd=*/true, kOuter, kInner>( return TwoMatVecT</*kAdd=*/true>(mat0, mat1, mat_ofs, outer, inner,
mat0, mat1, mat_ofs, vec_aligned, add0, add1, out0, out1, pool); vec_aligned, add0, add1, out0, out1, pool);
} }
// Without addition // Without addition
template <size_t kOuter, size_t kInner, typename ArrayT1, typename ArrayT2, template <typename ArrayT1, typename ArrayT2, typename VecT>
typename VecT>
HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1, HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs, const size_t mat_ofs, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1, float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
TwoMatVecT</*kAdd=*/false, kOuter, kInner, ArrayT1, ArrayT2, VecT, VecT>( TwoMatVecT</*kAdd=*/false, ArrayT1, ArrayT2, VecT, VecT>(
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr, mat0, mat1, mat_ofs, outer, inner, vec_aligned, /*add0=*/nullptr,
out0, out1, pool); /*add1=*/nullptr, out0, out1, pool);
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -21,11 +21,11 @@
#include <stddef.h> #include <stddef.h>
#include <stdio.h> #include <stdio.h>
#include <array>
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <random> #include <random>
#include <type_traits> // std::enable_if_t #include <type_traits> // std::enable_if_t
#include <vector>
#include "compression/compress.h" #include "compression/compress.h"
#include "util/basics.h" // TokenAndProb #include "util/basics.h" // TokenAndProb
@ -673,9 +673,8 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
return max_index; return max_index;
} }
template <size_t k> HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> std::vector<float>& top_k, float temperature) {
create_distribution(std::array<float, k>& top_k, float temperature) {
HWY_ASSERT(temperature >= 0.0f); HWY_ASSERT(temperature >= 0.0f);
if (temperature == 0.0f) { if (temperature == 0.0f) {
// Temperature == 0 is a special case which always returns the argmax (0). // Temperature == 0 is a special case which always returns the argmax (0).
@ -696,16 +695,16 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k)); return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
} }
template <size_t k, typename TAcceptToken> template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size, const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) { std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, ""); HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size); HWY_ASSERT(k <= vocab_size);
// TODO: Optimize, potentially using new VQSort PartialSort. // TODO: Optimize, potentially using new VQSort PartialSort.
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1] // Sorted from highest [0], to lowest [k-1]
top_k.fill(-std::numeric_limits<float>::infinity()); std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
std::array<int, k> indices{}; std::vector<int> indices(k);
size_t num_accepted = 0; size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) { for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1]) continue; if (probabilities[i] < top_k[k - 1]) continue;
@ -727,7 +726,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
} }
} }
HWY_ASSERT(k <= num_accepted); HWY_ASSERT(k <= num_accepted);
return indices[create_distribution<k>(top_k, temperature)(gen)]; return indices[create_distribution(top_k, temperature)(gen)];
} }
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -387,8 +387,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
using Config = ConfigGemma2_9B<float>; ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
int dim_qkv = Config::kQKVDim; int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(1, dim_qkv); RowVectorBatch<float> x(1, dim_qkv);
std::mt19937 gen; std::mt19937 gen;
@ -400,15 +400,15 @@ void TestRopeAndMulBy() {
x.All()[i] = random_float(); x.All()[i] = random_float();
} }
const float qmul = ChooseQueryScale<Config>(); const float qmul = ChooseQueryScale(config);
const float kmul = 1.0; const float kmul = 1.0;
std::vector<float> qexpected(dim_qkv); std::vector<float> qexpected(dim_qkv);
std::vector<float> qactual(dim_qkv); std::vector<float> qactual(dim_qkv);
std::vector<float> kexpected(dim_qkv); std::vector<float> kexpected(dim_qkv);
std::vector<float> kactual(dim_qkv); std::vector<float> kactual(dim_qkv);
RowVectorBatch<float> inv_timescale = RowVectorBatch<float> inv_timescale = gcpp::Activations::CreateInvTimescale(
gcpp::Activations::CreateInvTimescale<Config>(); config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
// Assert VectorizedRope computation is same as regular rope at different pos. // Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) { for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings // Rope'd Q embeddings
@ -571,20 +571,20 @@ void TestSampleTopK() {
float temperature = 1.0f; float temperature = 1.0f;
// SampleTopK<1> should return the argmax. // SampleTopK<1> should return the argmax.
std::function<bool(int, float)> accept_token; std::function<bool(int, float)> accept_token;
int sample = SampleTopK<1>(logits.data(), kSize, gen, temperature, int sample =
accept_token); SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
EXPECT_EQ(sample, 51); // Last is largest. EXPECT_EQ(sample, 51); // Last is largest.
// Only accept even tokens, expect the last (largest) even index. // Only accept even tokens, expect the last (largest) even index.
accept_token = [](int i, float) { return i % 2 == 0; }; accept_token = [](int i, float) { return i % 2 == 0; };
sample = SampleTopK<1>(logits.data(), kSize, gen, temperature, sample =
accept_token); SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
EXPECT_EQ(sample, 50); // Last even index. EXPECT_EQ(sample, 50); // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax. // Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f); std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits.data(), kSize); Softmax(logits.data(), kSize);
// Sample from the top 3, expect one of the top 3 even indices. // Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature, sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token); accept_token);
EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46); EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
} }
@ -592,7 +592,7 @@ void TestSampleTopK() {
// even for k=3. // even for k=3.
temperature = 0.0f; temperature = 0.0f;
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature, sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token); accept_token);
EXPECT_EQ(sample, 50); EXPECT_EQ(sample, 50);
} }

View File

@ -189,6 +189,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
const ModelInfo& Info() const { return info_; } const ModelInfo& Info() const { return info_; }
private: private:
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
// weights file, so we can remove the need for this struct entirely.
ModelInfo info_; ModelInfo info_;
}; };