mirror of https://github.com/google/gemma.cpp.git
Eliminated TConfig.
Changed CompressedLayer and CompressedWeights to be constructed from an instance of LayerConfig and ModelConfig respectively. Added CompressedModel to remove ByteStorageT and get rid of most of the type casting, and to allow the default destructor to be used and work properly. Adjusted WeightsWrapper, ForwardLayer, etc. to match.

The only remaining template argument is the weight type. This allows all per-model instantiations to be deleted, apart from one per weight type. It also makes it possible (not yet done) to store the config in the blob file instead of having to specify it separately.

Reduces the size of the gemma_lib and weights shared libraries by factors of 4.3 and 3.2 respectively.

PiperOrigin-RevId: 686870060
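To illustrate the pattern the diff below applies throughout, here is a minimal, self-contained sketch of the before/after shape of `ForwardLayer`/`ForwardPass`. It is not the actual gemma.cpp code: `std::vector<T>` stands in for `MatStorageT<T>`, only a few of the real members are shown, and the config structs are reduced to the fields used here (the real ones live in `gemma/configs.h`).

```cpp
#include <cstddef>
#include <vector>

// Simplified runtime configuration structs (stand-ins for gemma/configs.h).
struct LayerConfig {
  size_t model_dim = 0;
  size_t ff_hidden_dim = 0;
  size_t heads = 0;
  size_t qkv_dim = 0;
};

struct ModelConfig {
  size_t seq_len = 0;
  size_t vocab_size = 0;
  std::vector<LayerConfig> layer_configs;
};

// After the change: only the weight/activation type T is a template parameter.
// All tensor shapes come from the config passed to the constructor, so one
// instantiation per type T covers every model.
template <typename T>
struct ForwardLayer {
  ForwardLayer(const LayerConfig& config, size_t seq_len)
      : input(seq_len * config.model_dim),
        qkv(seq_len * (config.heads + 2) * config.qkv_dim),
        ffw_hidden(seq_len * config.ff_hidden_dim * 2),
        layer_config(config) {}

  std::vector<T> input;       // MatStorageT<T> in the real code
  std::vector<T> qkv;
  std::vector<T> ffw_hidden;
  const LayerConfig& layer_config;  // kept by reference, as in the diff
};

template <typename T>
struct ForwardPass {
  explicit ForwardPass(const ModelConfig& config)
      : logits(config.seq_len * config.vocab_size), weights_config(config) {
    // std::vector replaces the old std::array<ForwardLayer, kLayers>.
    for (const LayerConfig& lc : config.layer_configs) {
      layers.emplace_back(lc, config.seq_len);
    }
  }

  std::vector<ForwardLayer<T>> layers;
  std::vector<T> logits;
  const ModelConfig& weights_config;
};
```

Correspondingly, kernels that previously took dimensions as template parameters (e.g. `MatMulVJP<kCols, kRows>`) now receive them as ordinary `size_t` arguments, which is what lets the per-model instantiation files under `gemma/instantiations/` be dropped from the build.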
This commit is contained in:
parent a4d6adbc43
commit 0d68555f87
58 BUILD.bazel
@@ -104,8 +104,6 @@ cc_test(
|
|||
tags = ["hwy_ops_test"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
|
|
@@ -183,7 +181,10 @@ cc_test(
|
|||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = ["gemma/common.cc"],
|
||||
srcs = [
|
||||
"gemma/common.cc",
|
||||
"gemma/configs.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/common.h",
|
||||
"gemma/configs.h",
|
||||
|
|
@@ -195,12 +196,20 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "configs_test",
|
||||
srcs = ["gemma/configs_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
"@googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
|
|
@@ -219,7 +228,6 @@ cc_library(
|
|||
":common",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||
],
|
||||
|
|
@@ -239,30 +247,10 @@ cc_library(
|
|||
name = "gemma_lib",
|
||||
srcs = [
|
||||
"gemma/gemma.cc",
|
||||
"gemma/instantiations/27b_bf16.cc",
|
||||
"gemma/instantiations/27b_f32.cc",
|
||||
"gemma/instantiations/27b_sfp.cc",
|
||||
"gemma/instantiations/2b_bf16.cc",
|
||||
"gemma/instantiations/2b_f32.cc",
|
||||
"gemma/instantiations/2b_sfp.cc",
|
||||
"gemma/instantiations/7b_bf16.cc",
|
||||
"gemma/instantiations/7b_f32.cc",
|
||||
"gemma/instantiations/7b_sfp.cc",
|
||||
"gemma/instantiations/9b_bf16.cc",
|
||||
"gemma/instantiations/9b_f32.cc",
|
||||
"gemma/instantiations/9b_sfp.cc",
|
||||
"gemma/instantiations/tiny_bf16.cc",
|
||||
"gemma/instantiations/tiny_f32.cc",
|
||||
"gemma/instantiations/tiny_sfp.cc",
|
||||
"gemma/instantiations/gr2b_bf16.cc",
|
||||
"gemma/instantiations/gr2b_f32.cc",
|
||||
"gemma/instantiations/gr2b_sfp.cc",
|
||||
"gemma/instantiations/gemma2_2b_bf16.cc",
|
||||
"gemma/instantiations/gemma2_2b_f32.cc",
|
||||
"gemma/instantiations/gemma2_2b_sfp.cc",
|
||||
"gemma/instantiations/paligemma_224_bf16.cc",
|
||||
"gemma/instantiations/paligemma_224_f32.cc",
|
||||
"gemma/instantiations/paligemma_224_sfp.cc",
|
||||
"gemma/instantiations/bf16.cc",
|
||||
"gemma/instantiations/f32.cc",
|
||||
"gemma/instantiations/nuq.cc",
|
||||
"gemma/instantiations/sfp.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
|
|
@@ -327,8 +315,6 @@ cc_library(
|
|||
":threading",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"@highway//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@@ -367,7 +353,6 @@ cc_test(
|
|||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":tokenizer",
|
||||
"@googletest//:gtest_main",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
|
|
@@ -396,7 +381,6 @@ cc_binary(
|
|||
name = "single_benchmark",
|
||||
srcs = ["evals/benchmark.cc"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
|
|
@@ -405,7 +389,6 @@ cc_binary(
|
|||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
|
@@ -429,13 +412,11 @@ cc_binary(
|
|||
"evals/debug_prompt.cc",
|
||||
],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"@nlohmann_json//:json",
|
||||
],
|
||||
)
|
||||
|
|
@@ -444,7 +425,6 @@ cc_binary(
|
|||
name = "gemma_mmlu",
|
||||
srcs = ["evals/run_mmlu.cc"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":gemma_lib",
|
||||
|
|
@@ -488,7 +468,6 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":prompt",
|
||||
":weights",
|
||||
|
|
@@ -508,7 +487,6 @@ cc_library(
|
|||
"backprop/forward_scalar.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":prompt",
|
||||
":weights",
|
||||
|
|
@@ -525,7 +503,6 @@ cc_test(
|
|||
"backprop/test_util.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":prompt",
|
||||
|
|
@@ -599,6 +576,7 @@ cc_test(
|
|||
":threading",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:sfp",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@@ -68,34 +68,15 @@ set(SOURCES
|
|||
gemma/activations.h
|
||||
gemma/common.cc
|
||||
gemma/common.h
|
||||
gemma/configs.cc
|
||||
gemma/configs.h
|
||||
gemma/gemma-inl.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
gemma/instantiations/27b_bf16.cc
|
||||
gemma/instantiations/27b_f32.cc
|
||||
gemma/instantiations/27b_sfp.cc
|
||||
gemma/instantiations/2b_bf16.cc
|
||||
gemma/instantiations/2b_f32.cc
|
||||
gemma/instantiations/2b_sfp.cc
|
||||
gemma/instantiations/7b_bf16.cc
|
||||
gemma/instantiations/7b_f32.cc
|
||||
gemma/instantiations/7b_sfp.cc
|
||||
gemma/instantiations/9b_bf16.cc
|
||||
gemma/instantiations/9b_f32.cc
|
||||
gemma/instantiations/9b_sfp.cc
|
||||
gemma/instantiations/gr2b_bf16.cc
|
||||
gemma/instantiations/gr2b_f32.cc
|
||||
gemma/instantiations/gr2b_sfp.cc
|
||||
gemma/instantiations/tiny_bf16.cc
|
||||
gemma/instantiations/tiny_f32.cc
|
||||
gemma/instantiations/tiny_sfp.cc
|
||||
gemma/instantiations/gemma2_2b_bf16.cc
|
||||
gemma/instantiations/gemma2_2b_f32.cc
|
||||
gemma/instantiations/gemma2_2b_sfp.cc
|
||||
gemma/instantiations/paligemma_224_bf16.cc
|
||||
gemma/instantiations/paligemma_224_f32.cc
|
||||
gemma/instantiations/paligemma_224_sfp.cc
|
||||
gemma/instantiations/bf16.cc
|
||||
gemma/instantiations/f32.cc
|
||||
gemma/instantiations/nuq.cc
|
||||
gemma/instantiations/sfp.cc
|
||||
gemma/kv_cache.cc
|
||||
gemma/kv_cache.h
|
||||
gemma/tokenizer.cc
|
||||
|
|
|
|||
|
|
@@ -18,32 +18,27 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h" // MatStorageT
|
||||
#include "util/allocator.h" // ByteStorageT
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
struct ForwardLayer {
|
||||
ForwardLayer()
|
||||
: input("input", kSeqLen, kModelDim),
|
||||
pre_att_rms_out("pre_att_rms_out", kSeqLen, kModelDim),
|
||||
qkv("qkv", kSeqLen * (kHeads + 2), kQKVDim),
|
||||
att("att", kSeqLen * kHeads, kSeqLen),
|
||||
att_out("att_out", kSeqLen * kHeads, kQKVDim),
|
||||
att_post1("att_post1", kSeqLen, kModelDim),
|
||||
attention_out("attention_out", kSeqLen, kModelDim),
|
||||
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", kSeqLen, kModelDim),
|
||||
ffw_hidden("ffw_hidden", kSeqLen, kFFHiddenDim * 2),
|
||||
ffw_hidden_gated("ffw_hidden_gated", kSeqLen, kFFHiddenDim) {}
|
||||
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
ForwardLayer(const LayerConfig& config, size_t seq_len)
|
||||
: input("input", seq_len, config.model_dim),
|
||||
pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
|
||||
qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
|
||||
att("att", seq_len * config.heads, seq_len),
|
||||
att_out("att_out", seq_len * config.heads, config.qkv_dim),
|
||||
att_post1("att_post1", seq_len, config.model_dim),
|
||||
attention_out("attention_out", seq_len, config.model_dim),
|
||||
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
|
||||
ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
|
||||
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
|
||||
layer_config(config) {}
|
||||
|
||||
MatStorageT<T> input;
|
||||
MatStorageT<T> pre_att_rms_out;
|
||||
|
|
@@ -55,56 +50,30 @@ struct ForwardLayer {
|
|||
MatStorageT<T> bf_pre_ffw_rms_out;
|
||||
MatStorageT<T> ffw_hidden;
|
||||
MatStorageT<T> ffw_hidden_gated;
|
||||
const LayerConfig& layer_config;
|
||||
};
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
struct ForwardPass {
|
||||
ForwardPass()
|
||||
: final_layer_output("final_layer_output", kSeqLen, kModelDim),
|
||||
final_norm_output("final_norm_output", kSeqLen, kModelDim),
|
||||
logits("logits", kSeqLen, kVocabSize),
|
||||
probs("probs", kSeqLen, kVocabSize) {
|
||||
} // prevents placement-new calling memset
|
||||
ForwardPass(const ModelConfig& config)
|
||||
: final_layer_output("final_layer_output", config.seq_len,
|
||||
config.model_dim),
|
||||
final_norm_output("final_norm_output", config.seq_len,
|
||||
config.model_dim),
|
||||
logits("logits", config.seq_len, config.vocab_size),
|
||||
probs("probs", config.seq_len, config.vocab_size),
|
||||
weights_config(config) {
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
layers.emplace_back(layer_config, config.seq_len);
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
|
||||
std::array<ForwardLayer<T, TConfig>, kLayers> layers;
|
||||
std::vector<ForwardLayer<T>> layers;
|
||||
MatStorageT<T> final_layer_output;
|
||||
MatStorageT<T> final_norm_output;
|
||||
MatStorageT<T> logits;
|
||||
MatStorageT<T> probs;
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct AllocateForwardPass {
|
||||
ByteStorageT operator()() const {
|
||||
ByteStorageT c_weights_u8 = AllocateSizeof<ForwardPass<float, TConfig>>();
|
||||
auto* c_weights =
|
||||
reinterpret_cast<ForwardPass<float, TConfig>*>(c_weights_u8.get());
|
||||
new (c_weights) ForwardPass<float, TConfig>();
|
||||
return c_weights_u8;
|
||||
}
|
||||
};
|
||||
|
||||
// Owns activations and undoes the type erasure of AllocateAligned.
|
||||
template<typename T, typename TConfig>
|
||||
class ActivationsWrapper {
|
||||
using WrappedT = ForwardPass<T, TConfig>;
|
||||
|
||||
public:
|
||||
ActivationsWrapper()
|
||||
: data_(AllocateSizeof<WrappedT>()),
|
||||
activations_(*(new(data_.get()) WrappedT())) {}
|
||||
|
||||
const WrappedT& get() const { return activations_; }
|
||||
WrappedT& get() { return activations_; }
|
||||
|
||||
private:
|
||||
ByteStorageT data_;
|
||||
WrappedT& activations_;
|
||||
const ModelConfig& weights_config;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@@ -28,6 +28,7 @@
|
|||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@@ -53,45 +54,41 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
template <size_t kCols, size_t kRows>
|
||||
void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
|
||||
HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols,
|
||||
const float* HWY_RESTRICT x, // num_tokens * kCols
|
||||
const float* HWY_RESTRICT v, // num_tokens * kRows
|
||||
size_t num_tokens,
|
||||
size_t cols, size_t rows, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // kRows * kCols,
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * kCols
|
||||
hwy::ThreadPool& pool) {
|
||||
hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
|
||||
hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t voffs = pos * kRows;
|
||||
const size_t xoffs = pos * kCols;
|
||||
for (size_t j = 0; j < kRows; ++j) {
|
||||
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
|
||||
MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
|
||||
kCols);
|
||||
const size_t voffs = pos * rows;
|
||||
const size_t xoffs = pos * cols;
|
||||
for (size_t j = 0; j < rows; ++j) {
|
||||
MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols);
|
||||
MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t kHeads, size_t kCols, size_t kRows>
|
||||
void MultiHeadMatMulVJP(
|
||||
const float* HWY_RESTRICT weights, // kHeads * kRows * kCols
|
||||
const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols
|
||||
HWY_INLINE void MultiHeadMatMulVJP(
|
||||
const float* HWY_RESTRICT weights, // heads * kRows * kCols
|
||||
const float* HWY_RESTRICT x, // num_tokens * heads * kCols
|
||||
const float* HWY_RESTRICT v, // num_tokens * kRows
|
||||
size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // kHeads * kRows * kCols
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols
|
||||
size_t heads, size_t cols, size_t rows, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, // heads * kRows * kCols
|
||||
float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols
|
||||
hwy::ThreadPool& pool) {
|
||||
hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
|
||||
hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0]));
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t j = 0; j < kRows; ++j) {
|
||||
for (size_t h = 0; h < kHeads; ++h) {
|
||||
MulByConstAndAdd(v[pos * kRows + j],
|
||||
&x[pos * kHeads * kCols + h * kCols],
|
||||
&grad_w[h * kRows * kCols + j * kCols], kCols);
|
||||
MulByConstAndAdd(v[pos * kRows + j],
|
||||
&weights[h * kRows * kCols + j * kCols],
|
||||
&grad_x[pos * kHeads * kCols + h * kCols], kCols);
|
||||
for (size_t j = 0; j < rows; ++j) {
|
||||
for (size_t h = 0; h < heads; ++h) {
|
||||
MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols],
|
||||
&grad_w[h * rows * cols + j * cols], cols);
|
||||
MulByConstAndAdd(v[pos * rows + j],
|
||||
&weights[h * rows * cols + j * cols],
|
||||
&grad_x[pos * heads * cols + h * cols], cols);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -168,39 +165,39 @@ static HWY_NOINLINE void InputEmbeddingVJP(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename TConfig, typename LayerT>
|
||||
void LayerVJP(const LayerT& weights,
|
||||
const ForwardLayer<float, TConfig>& forward,
|
||||
template <typename T>
|
||||
void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
||||
const ForwardLayer<float>& forward,
|
||||
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
|
||||
LayerT& grad, ForwardLayer<float, TConfig>& backward,
|
||||
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
|
||||
const RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
static const float kQueryScale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
HWY_ASSERT(num_tokens <= kSeqLen);
|
||||
const LayerConfig& config = weights.layer_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t qkv_dim = config.qkv_dim;
|
||||
const size_t heads = config.heads;
|
||||
const size_t seq_len = forward.input.Rows();
|
||||
const size_t ff_hidden_dim = config.ff_hidden_dim;
|
||||
const float query_scale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
|
||||
HWY_ASSERT(num_tokens <= seq_len);
|
||||
|
||||
MatMulVJP<kFFHiddenDim, kModelDim>(
|
||||
weights.linear_w.data(), forward.ffw_hidden_gated.data(), next_layer_grad,
|
||||
num_tokens, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
|
||||
pool);
|
||||
MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
|
||||
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
|
||||
grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t hidden_offset = pos * kFFHiddenDim * 2;
|
||||
const size_t hidden_offset = pos * ff_hidden_dim * 2;
|
||||
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
|
||||
const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
|
||||
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
|
||||
const float* HWY_RESTRICT b_out_gated =
|
||||
backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
|
||||
backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
|
||||
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
|
||||
float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
|
||||
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
DF df;
|
||||
for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
|
||||
for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) {
|
||||
const auto y = Load(df, f_out + i);
|
||||
const auto x = Load(df, f_out_mul + i);
|
||||
const auto v = Load(df, b_out_gated + i);
|
||||
|
|
@@ -209,101 +206,94 @@ void LayerVJP(const LayerT& weights,
|
|||
}
|
||||
}
|
||||
|
||||
MatMulVJP<kModelDim, kFFHiddenDim * 2>(
|
||||
weights.gating_einsum_w.data(),
|
||||
forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
|
||||
MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
|
||||
backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
|
||||
num_tokens, grad.gating_einsum_w.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), pool);
|
||||
RMSNormVJP(weights.pre_ffw_norm_scale.data(),
|
||||
forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(),
|
||||
kModelDim, num_tokens,
|
||||
grad.pre_ffw_norm_scale.data(),
|
||||
backward.attention_out.data(), pool);
|
||||
RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
|
||||
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
|
||||
pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(next_layer_grad + pos * kModelDim,
|
||||
backward.attention_out.data() + pos * kModelDim, kModelDim);
|
||||
AddFrom(next_layer_grad + pos * model_dim,
|
||||
backward.attention_out.data() + pos * model_dim, model_dim);
|
||||
}
|
||||
|
||||
backward.qkv.ZeroInit();
|
||||
|
||||
MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
|
||||
weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(), num_tokens,
|
||||
grad.attn_vec_einsum_w.data(), backward.att_out.data(), pool);
|
||||
MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(), heads, qkv_dim, model_dim,
|
||||
num_tokens, grad.attn_vec_einsum_w.data(),
|
||||
backward.att_out.data(), pool);
|
||||
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
|
||||
const float* HWY_RESTRICT b_att_out =
|
||||
backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
|
||||
backward.att_out.data() + (pos * heads + head) * qkv_dim;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
|
||||
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
|
||||
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
|
||||
b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
|
||||
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
|
||||
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
|
||||
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
|
||||
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
|
||||
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
|
||||
const size_t aoffs = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
|
||||
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
|
||||
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
|
||||
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
|
||||
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
|
||||
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
|
||||
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
|
||||
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
||||
float* HWY_RESTRICT b_kv =
|
||||
backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos);
|
||||
backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
|
||||
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
float* HWY_RESTRICT b_q =
|
||||
backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
MulByConst(kQueryScale, b_q, kQKVDim);
|
||||
Rope(b_q, kQKVDim, inv_timescale.Const(), -pos);
|
||||
backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
|
||||
MulByConst(query_scale, b_q, qkv_dim);
|
||||
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
|
||||
weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), num_tokens,
|
||||
MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
|
||||
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
|
||||
RMSNormVJP(weights.pre_attention_norm_scale.data(),
|
||||
forward.input.data(),
|
||||
backward.pre_att_rms_out.data(),
|
||||
kModelDim, num_tokens,
|
||||
grad.pre_attention_norm_scale.data(),
|
||||
backward.input.data(), pool);
|
||||
RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
|
||||
backward.pre_att_rms_out.data(), model_dim, num_tokens,
|
||||
grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(backward.attention_out.data() + pos * kModelDim,
|
||||
backward.input.data() + pos * kModelDim, kModelDim);
|
||||
AddFrom(backward.attention_out.data() + pos * model_dim,
|
||||
backward.input.data() + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -342,20 +332,22 @@ static HWY_NOINLINE void CrossEntropyLossGrad(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename TConfig, typename WeightsT, typename LayerT>
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
|
||||
const ForwardPass<float, TConfig>& forward,
|
||||
WeightsT& grad,
|
||||
ForwardPass<float, TConfig>& backward,
|
||||
template <typename T>
|
||||
void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<T>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
static_assert(!TConfig::kAbsolutePE);
|
||||
static_assert(TConfig::kPostNorm == PostNormType::None);
|
||||
static_assert(TConfig::kKVHeads == 1);
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t kVocabSize = config.vocab_size;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t kLayers = config.layer_configs.size();
|
||||
const float kEmbScaling = EmbeddingScaling(model_dim);
|
||||
HWY_ASSERT(!config.absolute_pe);
|
||||
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
|
||||
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
|
||||
|
||||
HWY_DASSERT(prompt.context_size > 0);
|
||||
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
|
||||
|
|
@@ -370,42 +362,38 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights,
|
|||
kVocabSize);
|
||||
}
|
||||
|
||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
SoftcapVJP(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
|
||||
SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
|
||||
backward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP<kModelDim, kVocabSize>(
|
||||
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
|
||||
backward.logits.data(), num_tokens,
|
||||
grad.embedder_input_embedding.data(), backward.final_norm_output.data(),
|
||||
MatMulVJP(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(), backward.logits.data(), model_dim,
|
||||
kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(), pool);
|
||||
|
||||
RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(), model_dim, num_tokens,
|
||||
grad.final_norm_scale.data(), backward.final_layer_output.data(),
|
||||
pool);
|
||||
|
||||
RMSNormVJP(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(),
|
||||
kModelDim, num_tokens,
|
||||
grad.final_norm_scale.data(),
|
||||
backward.final_layer_output.data(), pool);
|
||||
|
||||
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
auto layer_config = config.layer_configs[layer];
|
||||
// TODO(szabadka) Implement Griffin layer vjp.
|
||||
HWY_ASSERT(type == LayerAttentionType::kGemma);
|
||||
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
|
||||
float* next_layer_grad = layer + 1 < kLayers
|
||||
? backward.layers[layer + 1].input.data()
|
||||
: backward.final_layer_output.data();
|
||||
LayerVJP<TConfig, LayerT>(*weights.GetLayer(layer), forward.layers[layer],
|
||||
next_layer_grad, num_tokens,
|
||||
*grad.GetLayer(layer), backward.layers[layer],
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
|
||||
inv_timescale, pool);
|
||||
}
|
||||
|
||||
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
|
||||
kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), kModelDim);
|
||||
grad.embedder_input_embedding.data(), model_dim);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@@ -38,44 +38,15 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <typename TConfig>
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ByteStorageT& weights_u8,
|
||||
const ByteStorageT& forward_u8,
|
||||
ByteStorageT& grad_u8,
|
||||
ByteStorageT& backward_u8,
|
||||
void CrossEntropyLossBackwardPassT(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
using TWeights = CompressedWeights<TConfig>;
|
||||
const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
|
||||
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
|
||||
using TAct = ForwardPass<float, TConfig>;
|
||||
const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
|
||||
auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
|
||||
CrossEntropyLossBackwardPass<TConfig, CompressedWeights<TConfig>,
|
||||
CompressedLayer<TConfig>>(
|
||||
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
}
|
||||
|
||||
void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
const ByteStorageT& forward,
|
||||
ByteStorageT& grad, ByteStorageT& backward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
// TODO(janwas): use CallFunctorForModel
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
CrossEntropyLossBackwardPass<ConfigGemma2B<float>>(
|
||||
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
break;
|
||||
case Model::GEMMA_TINY:
|
||||
CrossEntropyLossBackwardPass<ConfigGemmaTiny<float>>(
|
||||
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
|
||||
inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@@ -87,14 +58,15 @@ namespace gcpp {
|
|||
|
||||
HWY_EXPORT(CrossEntropyLossBackwardPassT);
|
||||
|
||||
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
const ByteStorageT& forward,
|
||||
ByteStorageT& grad, ByteStorageT& backward,
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
|
||||
model, prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
prompt, weights, forward, grad, backward, inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@@ -16,17 +16,19 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
const ByteStorageT& forward,
|
||||
ByteStorageT& grad, ByteStorageT& backward,
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
const ForwardPass<float>& forward,
|
||||
ModelWeightsPtrs<float>& grad,
|
||||
ForwardPass<float>& backward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
|
|
|
|||
|
|
@@ -125,65 +125,64 @@ void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
|
||||
size_t num_tokens, size_t kHeads, size_t kQKVDim,
|
||||
size_t kSeqLen) {
|
||||
size_t num_tokens, size_t kHeads, size_t qkv_dim,
|
||||
size_t seq_len) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t offset = pos * (kHeads + 2) * kQKVDim;
|
||||
memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
|
||||
const size_t offset = pos * (kHeads + 2) * qkv_dim;
|
||||
memset(dqkv + offset, 0, (kHeads + 1) * qkv_dim * sizeof(qkv[0]));
|
||||
}
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
|
||||
const size_t qoffs = (pos * (kHeads + 2) + head) * qkv_dim;
|
||||
const size_t aoffs = head * seq_len + pos * kHeads * seq_len;
|
||||
const T* q = qkv + qoffs;
|
||||
const T* dout = doutput + aoffs;
|
||||
T* dq = dqkv + qoffs;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * qkv_dim;
|
||||
const T* k = qkv + koffs;
|
||||
T* dk = dqkv + koffs;
|
||||
MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
|
||||
MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
|
||||
MulByConstAndAddT(dout[pos2], k, dq, qkv_dim);
|
||||
MulByConstAndAddT(dout[pos2], q, dk, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
|
||||
size_t kHeads, size_t kSeqLen) {
|
||||
void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, size_t kHeads,
|
||||
size_t seq_len) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
|
||||
size_t offset = pos * kHeads * seq_len + head * seq_len;
|
||||
SoftmaxVJPT(y + offset, dy + offset, pos + 1);
|
||||
memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
|
||||
memset(dy + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
|
||||
T* dqkv, T* dattention, size_t num_tokens,
|
||||
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
|
||||
T* dqkv, T* dattention, size_t num_tokens, size_t kHeads,
|
||||
size_t qkv_dim, size_t seq_len) {
|
||||
auto v_offset = [&](size_t pos) {
|
||||
return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
|
||||
return (pos * (kHeads + 2) + kHeads + 1) * qkv_dim;
|
||||
};
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
|
||||
memset(&dqkv[v_offset(pos)], 0, qkv_dim * sizeof(qkv[0]));
|
||||
}
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
|
||||
const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
|
||||
const size_t offset = head * qkv_dim + pos * kHeads * qkv_dim;
|
||||
const size_t aoffset = head * seq_len + pos * kHeads * seq_len;
|
||||
const T* att = &attention[aoffset];
|
||||
const T* dout = &doutput[offset];
|
||||
T* datt = &dattention[aoffset];
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
|
||||
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
|
||||
datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], qkv_dim);
|
||||
MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -199,77 +198,76 @@ void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
void LayerVJP(const CompressedLayer<TConfig>& weights,
|
||||
const ForwardLayer<T, TConfig>& forward, const T* dy,
|
||||
CompressedLayer<TConfig>& grad,
|
||||
ForwardLayer<T, TConfig>& backward, size_t num_tokens) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
|
||||
template <typename T>
|
||||
void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
||||
const ForwardLayer<T>& forward, const T* dy,
|
||||
LayerWeightsPtrs<T>& grad, ForwardLayer<T>& backward,
|
||||
size_t num_tokens) {
|
||||
const LayerConfig& layer_config = weights.layer_config;
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t seq_len = forward.input.Rows();
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t kHeads = layer_config.heads;
|
||||
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
|
||||
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
|
||||
|
||||
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
|
||||
dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
|
||||
kModelDim, kFFHiddenDim, num_tokens);
|
||||
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
|
||||
grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
|
||||
kFFHiddenDim, num_tokens);
|
||||
|
||||
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
|
||||
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
|
||||
|
||||
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
|
||||
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
|
||||
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(),
|
||||
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
|
||||
kModelDim, num_tokens);
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
|
||||
AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
|
||||
|
||||
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(),
|
||||
grad.attn_vec_einsum_w.data(),
|
||||
backward.att_out.data(),
|
||||
kHeads, kModelDim, kQKVDim, num_tokens);
|
||||
grad.attn_vec_einsum_w.data(), backward.att_out.data(),
|
||||
kHeads, model_dim, qkv_dim, num_tokens);
|
||||
|
||||
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
|
||||
backward.att_out.data(), backward.qkv.data(),
|
||||
backward.att.data(), num_tokens, kHeads, kQKVDim,
|
||||
kSeqLen);
|
||||
backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
|
||||
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
|
||||
num_tokens, kHeads, kSeqLen);
|
||||
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
|
||||
seq_len);
|
||||
|
||||
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
|
||||
backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
|
||||
backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= kHeads; ++h) {
|
||||
Rope(qkv + h * kQKVDim, kQKVDim, -pos);
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), grad.qkv_einsum_w.data(),
|
||||
backward.pre_att_rms_out.data(),
|
||||
(kHeads + 2) * kQKVDim, kModelDim, num_tokens);
|
||||
backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
|
||||
num_tokens);
|
||||
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
|
||||
backward.pre_att_rms_out.data(),
|
||||
grad.pre_attention_norm_scale.data(),
|
||||
backward.input.data(), kModelDim, num_tokens);
|
||||
grad.pre_attention_norm_scale.data(), backward.input.data(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(backward.attention_out.data(), backward.input.data(),
|
||||
num_tokens * kModelDim);
|
||||
num_tokens * model_dim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@@ -296,56 +294,54 @@ void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
const ForwardPass<T, TConfig>& forward,
|
||||
CompressedWeights<TConfig>& grad,
|
||||
ForwardPass<T, TConfig>& backward) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
const ForwardPass<T>& forward,
|
||||
ModelWeightsPtrs<T>& grad,
|
||||
ForwardPass<T>& backward) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
||||
kVocabSize);
|
||||
vocab_size);
|
||||
|
||||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
|
||||
kVocabSize, num_tokens);
|
||||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
|
||||
num_tokens);
|
||||
|
||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
SoftcapVJPT(TConfig::kFinalCap, forward.logits.data() + i * kVocabSize,
|
||||
backward.logits.data() + i * kVocabSize, kVocabSize);
|
||||
SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
|
||||
backward.logits.data() + i * vocab_size, vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(),
|
||||
backward.logits.data(),
|
||||
grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(),
|
||||
kVocabSize, kModelDim, num_tokens);
|
||||
MatMulVJPT(
|
||||
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
|
||||
backward.logits.data(), grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(),
|
||||
grad.final_norm_scale.data(),
|
||||
backward.final_layer_output.data(), kModelDim, num_tokens);
|
||||
backward.final_norm_output.data(), grad.final_norm_scale.data(),
|
||||
backward.final_layer_output.data(), model_dim, num_tokens);
|
||||
|
||||
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
|
||||
T* next_layer_grad = layer + 1 < kLayers
|
||||
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
|
||||
T* next_layer_grad = layer + 1 < layers
|
||||
? backward.layers[layer + 1].input.data()
|
||||
: backward.final_layer_output.data();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
|
||||
}
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
|
||||
tokens, kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), kModelDim);
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
|
||||
kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), model_dim);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@@ -19,7 +19,6 @@
|
|||
#include <stdio.h>
|
||||
#include <string.h> // memcpy
|
||||
|
||||
#include <array>
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
|
@@ -384,44 +383,49 @@ TEST(BackPropTest, InputEmbeddingVJP) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TestConfig : ConfigBaseGemmaV2 {
|
||||
using Weight = T;
|
||||
static constexpr int kSeqLen = 18;
|
||||
static constexpr int kVocabSize = 12;
|
||||
static constexpr int kModelDim = 32;
|
||||
static constexpr int kHeads = 3;
|
||||
static constexpr int kQKVDim = 12;
|
||||
static constexpr int kFFHiddenDim = 48;
|
||||
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
|
||||
FixedLayerConfig<2>(LayerAttentionType::kGemma);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static ModelConfig TestConfig() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 12;
|
||||
config.seq_len = 18;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 48,
|
||||
.heads = 3,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 12,
|
||||
};
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.final_cap = 30.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
TEST(BackPropTest, LayerVJP) {
|
||||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
const size_t kOutputSize = TestConfig<T>::kSeqLen * TestConfig<T>::kModelDim;
|
||||
CompressedLayer<TestConfig<T>> weights;
|
||||
CompressedLayer<TestConfig<T>> grad;
|
||||
ForwardLayer<T, TestConfig<T>> forward;
|
||||
ForwardLayer<T, TestConfig<T>> backward = {};
|
||||
CompressedLayer<TestConfig<TC>> c_weights;
|
||||
ForwardLayer<TC, TestConfig<TC>> c_forward;
|
||||
std::array<T, kOutputSize> y;
|
||||
ModelConfig config = TestConfig();
|
||||
const size_t kOutputSize = config.seq_len * config.model_dim;
|
||||
LayerWeightsPtrs<T> weights(config.layer_configs[0]);
|
||||
LayerWeightsPtrs<T> grad(config.layer_configs[0]);
|
||||
ForwardLayer<T> forward(config.layer_configs[0], config.seq_len);
|
||||
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
|
||||
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0]);
|
||||
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
|
||||
MatStorageT<T> y("y", kOutputSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutputSize, 1);
|
||||
std::array<TC, kOutputSize> c_y;
|
||||
MatStorageT<TC> c_y("c_y", kOutputSize, 1);
|
||||
const size_t num_tokens = 3;
|
||||
weights.Allocate();
|
||||
grad.Allocate();
|
||||
c_weights.Allocate();
|
||||
std::vector<MatStorage> layer_storage;
|
||||
weights.Allocate(layer_storage);
|
||||
grad.Allocate(layer_storage);
|
||||
c_weights.Allocate(layer_storage);
|
||||
backward.input.ZeroInit();
|
||||
|
||||
for (size_t iter = 0; iter < 10; ++iter) {
|
||||
|
|
@@ -432,7 +436,7 @@ TEST(BackPropTest, LayerVJP) {
|
|||
Complexify(forward.input, c_forward.input);
|
||||
auto func = [&]() {
|
||||
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
|
||||
return DotT(dy.data(), c_y.data(), num_tokens * TestConfig<T>::kModelDim);
|
||||
return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
|
||||
};
|
||||
grad.ZeroInit(/*layer_idx=*/0);
|
||||
ApplyLayer(weights, forward, num_tokens, y.data());
|
||||
|
|
@@ -447,12 +451,13 @@ TEST(BackPropTest, EndToEnd) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
WeightsWrapper<TestConfig<T>> weights;
|
||||
WeightsWrapper<TestConfig<T>> grad;
|
||||
ForwardPass<T, TestConfig<T>> forward;
|
||||
ForwardPass<T, TestConfig<T>> backward;
|
||||
WeightsWrapper<TestConfig<TC>> c_weights;
|
||||
ForwardPass<TC, TestConfig<TC>> c_forward;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
|
||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
|
@@ -474,9 +479,9 @@ TEST(BackPropTest, EndToEnd) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
|
||||
CompressedLayer<TConfig>& out) {
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const LayerWeightsPtrs<T>& x,
|
||||
LayerWeightsPtrs<T>& out) {
|
||||
MulByConstAndAddT(c, x.pre_attention_norm_scale,
|
||||
out.pre_attention_norm_scale);
|
||||
MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
|
||||
|
|
@@ -486,23 +491,23 @@ void MulByConstAndAddT(T c, const CompressedLayer<TConfig>& x,
|
|||
MulByConstAndAddT(c, x.linear_w, out.linear_w);
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
void MulByConstAndAddT(T c, const CompressedWeights<TConfig>& x,
|
||||
CompressedWeights<TConfig>& out) {
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
template <typename T>
|
||||
void MulByConstAndAddT(T c, const ModelWeightsPtrs<T>& x,
|
||||
ModelWeightsPtrs<T>& out) {
|
||||
const size_t layers = x.c_layers.size();
|
||||
MulByConstAndAddT(c, x.embedder_input_embedding,
|
||||
out.embedder_input_embedding);
|
||||
MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
|
||||
for (size_t i = 0; i < kLayers; ++i) {
|
||||
for (size_t i = 0; i < layers; ++i) {
|
||||
MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluates forward pass on a batch.
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<TConfig>& weights,
|
||||
ForwardPass<T, TConfig>& forward) {
|
||||
const WeightsWrapper<T>& weights,
|
||||
ForwardPass<T>& forward) {
|
||||
T loss = 0.0;
|
||||
for (const Prompt& prompt : batch) {
|
||||
loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
|
||||
|
|
@@ -514,12 +519,11 @@ T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
|
|||
// Evaluates forward pass on a batch by applying gradient with the given
|
||||
// learning rate. Does not update weights, but uses the given tmp weights
|
||||
// instead.
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
|
||||
const WeightsWrapper<TConfig>& weights,
|
||||
const WeightsWrapper<TConfig>& grad,
|
||||
WeightsWrapper<TConfig>& tmp,
|
||||
ForwardPass<T, TConfig>& forward) {
|
||||
const WeightsWrapper<T>& weights,
|
||||
const WeightsWrapper<T>& grad,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward) {
|
||||
tmp.CopyFrom(weights);
|
||||
const T scale = -learning_rate / batch.size();
|
||||
MulByConstAndAddT(scale, grad.get(), tmp.get());
|
||||
|
|
@@ -529,11 +533,9 @@ T CrossEntropyLossForwardPass(T learning_rate, const std::vector<Prompt>& batch,
|
|||
// Uses line search in the negative gradient direction to update weights. We do
|
||||
// this so that we can test that each step during the gradient descent can
|
||||
// decrease the objective function value.
|
||||
template <typename T, typename TConfig>
|
||||
T FindOptimalUpdate(const WeightsWrapper<TConfig>& grad,
|
||||
WeightsWrapper<TConfig>& weights,
|
||||
WeightsWrapper<TConfig>& tmp,
|
||||
ForwardPass<T, TConfig>& forward,
|
||||
template <typename T>
|
||||
T FindOptimalUpdate(const WeightsWrapper<T>& grad, WeightsWrapper<T>& weights,
|
||||
WeightsWrapper<T>& tmp, ForwardPass<T>& forward,
|
||||
const std::vector<Prompt>& batch, T loss,
|
||||
T initial_learning_rate) {
|
||||
T lr0 = initial_learning_rate;
|
||||
|
|
@@ -568,13 +570,14 @@ TEST(BackProptest, Convergence) {
|
|||
std::mt19937 gen(42);
|
||||
using T = float;
|
||||
using TC = std::complex<double>;
|
||||
WeightsWrapper<TestConfig<T>> weights;
|
||||
WeightsWrapper<TestConfig<T>> grad;
|
||||
WeightsWrapper<TestConfig<T>> tmp;
|
||||
ForwardPass<T, TestConfig<T>> forward;
|
||||
ForwardPass<T, TestConfig<T>> backward;
|
||||
WeightsWrapper<TestConfig<TC>> c_weights;
|
||||
ForwardPass<TC, TestConfig<TC>> c_forward;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<T> weights(config);
|
||||
WeightsWrapper<T> grad(config);
|
||||
WeightsWrapper<T> tmp(config);
|
||||
ForwardPass<T> forward(config);
|
||||
ForwardPass<T> backward(config);
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
constexpr size_t kBatchSize = 5;
|
||||
ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
|
||||
T learning_rate = 0.01;
|
||||
|
|
|
|||
|
|
@@ -19,7 +19,6 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <complex>
|
||||
#include <cstdlib> // std::abs
|
||||
#include <random>
|
||||
|
|
@@ -34,7 +33,6 @@
|
|||
#include "backprop/test_util.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -50,6 +48,7 @@
|
|||
#include "backprop/forward-inl.h"
|
||||
#include "compression/compress.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "util/allocator.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -85,7 +84,7 @@ void TestMatMulVJP() {
|
|||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MatMulVJP<kCols, kRows>(weights.data(), x.data(), dy.data(), kTokens,
|
||||
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
|
||||
grad.data(), dx.data(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
|
@ -130,9 +129,8 @@ void TestMultiHeadMatMulVJP() {
|
|||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MultiHeadMatMulVJP<kHeads, kCols, kRows>(
|
||||
weights.data(), x.data(), dy.data(), kTokens, grad.data(), dx.data(),
|
||||
pool);
|
||||
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
|
||||
kRows, kTokens, grad.data(), dx.data(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
|
|
@ -186,63 +184,63 @@ void TestRMSNormVJP() {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct TestConfig : ConfigBaseGemmaV2 {
|
||||
using Weight = T;
|
||||
static constexpr int kSeqLen = 24;
|
||||
static constexpr int kVocabSize = 16;
|
||||
static constexpr int kModelDim = 32;
|
||||
static constexpr int kHeads = 3;
|
||||
static constexpr int kQKVDim = 16;
|
||||
static constexpr int kFFHiddenDim = 64;
|
||||
static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
|
||||
FixedLayerConfig<2>(LayerAttentionType::kGemma);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static ModelConfig TestConfig() {
|
||||
ModelConfig config;
|
||||
config.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w",
|
||||
"gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"};
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 16;
|
||||
config.seq_len = 24;
|
||||
LayerConfig layer_config = {
|
||||
.model_dim = config.model_dim,
|
||||
.ff_hidden_dim = 64,
|
||||
.heads = 3,
|
||||
.kv_heads = 1,
|
||||
.qkv_dim = 16,
|
||||
};
|
||||
config.layer_configs = {2, layer_config};
|
||||
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<2>(32);
|
||||
// This is required for optimize_test to pass.
|
||||
config.att_cap = 50.0f;
|
||||
config.final_cap = 30.0f;
|
||||
return config;
|
||||
}
|
||||
|
||||
void TestEndToEnd() {
|
||||
std::mt19937 gen(42);
|
||||
hwy::ThreadPool pool(0);
|
||||
using WeightsF = CompressedWeights<TestConfig<float>>;
|
||||
using LayerF = CompressedLayer<TestConfig<float>>;
|
||||
WeightsWrapper<TestConfig<float>> weights;
|
||||
WeightsWrapper<TestConfig<float>> grad;
|
||||
ActivationsWrapper<float, TestConfig<float>> forward0;
|
||||
ActivationsWrapper<float, TestConfig<float>> forward1;
|
||||
ActivationsWrapper<float, TestConfig<float>> backward;
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<float> weights(config);
|
||||
WeightsWrapper<float> grad(config);
|
||||
ForwardPass<float> forward0(config);
|
||||
ForwardPass<float> forward1(config);
|
||||
ForwardPass<float> backward(config);
|
||||
using TC = std::complex<double>;
|
||||
WeightsWrapper<TestConfig<TC>> c_weights;
|
||||
ForwardPass<TC, TestConfig<TC>> c_forward;
|
||||
WeightsWrapper<TC> c_weights(config);
|
||||
ForwardPass<TC> c_forward(config);
|
||||
|
||||
ReverseSequenceSampler training_task({0, 0, 1, 1});
|
||||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
||||
RowVectorBatch<float> inv_timescale =
|
||||
Activations::CreateInvTimescale<TestConfig<float>>();
|
||||
RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
|
||||
config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
RandInit(weights.get(), 1.0f, gen);
|
||||
|
||||
float loss0 = CrossEntropyLossForwardPass(
|
||||
prompt, weights.get(), forward0.get());
|
||||
float loss0 = CrossEntropyLossForwardPass(prompt, weights.get(), forward0);
|
||||
|
||||
float loss1 =
|
||||
CrossEntropyLossForwardPass<TestConfig<float>, WeightsF, LayerF>(
|
||||
prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
|
||||
float loss1 = CrossEntropyLossForwardPass(
|
||||
prompt.tokens, prompt.context_size, weights.get(), forward1,
|
||||
inv_timescale, pool);
|
||||
|
||||
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
||||
|
||||
grad.ZeroInit();
|
||||
CrossEntropyLossBackwardPass<TestConfig<float>, WeightsF, LayerF>(
|
||||
prompt, weights.get(), forward1.get(), grad.get(), backward.get(),
|
||||
inv_timescale, pool);
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
|
||||
backward, inv_timescale, pool);
|
||||
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
@ -26,6 +26,7 @@
|
|||
#include "backprop/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -93,28 +94,28 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs,
|
|||
return loss * scaling;
|
||||
}
|
||||
|
||||
template <typename TConfig, typename LayerT>
|
||||
void ApplyForwardLayer(const LayerT& weights,
|
||||
ForwardLayer<float, TConfig>& activations,
|
||||
size_t num_tokens, float* HWY_RESTRICT output,
|
||||
template <typename T>
|
||||
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
||||
ForwardLayer<float>& activations, size_t num_tokens,
|
||||
float* HWY_RESTRICT output,
|
||||
const RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static const float kQueryScale =
|
||||
const LayerConfig& config = weights.layer_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t kSeqLen = activations.input.Rows();
|
||||
const size_t kQKVDim = config.qkv_dim;
|
||||
const size_t kHeads = config.heads;
|
||||
static const float query_scale =
|
||||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
HWY_ASSERT(num_tokens <= kSeqLen);
|
||||
|
||||
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
|
||||
activations.input.data(), kModelDim, num_tokens,
|
||||
activations.input.data(), model_dim, num_tokens,
|
||||
activations.pre_att_rms_out.data(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec<(kHeads + 2) * kQKVDim, kModelDim>(
|
||||
weights.qkv_einsum_w, 0,
|
||||
activations.pre_att_rms_out.data() + pos * kModelDim,
|
||||
MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
|
||||
activations.pre_att_rms_out.data() + pos * model_dim,
|
||||
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
|
||||
}
|
||||
const size_t num_tasks = kHeads * num_tokens;
|
||||
|
|
@ -130,7 +131,7 @@ void ApplyForwardLayer(const LayerT& weights,
|
|||
float* HWY_RESTRICT q =
|
||||
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
Rope(q, kQKVDim, inv_timescale.Const(), pos);
|
||||
MulByConst(kQueryScale, q, kQKVDim);
|
||||
MulByConst(query_scale, q, kQKVDim);
|
||||
});
|
||||
|
||||
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
|
||||
|
|
@ -174,28 +175,28 @@ void ApplyForwardLayer(const LayerT& weights,
|
|||
activations.attention_out.ZeroInit();
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
MatVec<kModelDim, kQKVDim>(
|
||||
weights.attn_vec_einsum_w, head * kModelDim * kQKVDim,
|
||||
MatVec(
|
||||
weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
|
||||
kQKVDim,
|
||||
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
|
||||
activations.att_post1.data() + pos * kModelDim, pool);
|
||||
AddFrom(activations.att_post1.data() + pos * kModelDim,
|
||||
activations.attention_out.data() + pos * kModelDim, kModelDim);
|
||||
activations.att_post1.data() + pos * model_dim, pool);
|
||||
AddFrom(activations.att_post1.data() + pos * model_dim,
|
||||
activations.attention_out.data() + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(activations.input.data() + pos * kModelDim,
|
||||
activations.attention_out.data() + pos * kModelDim, kModelDim);
|
||||
AddFrom(activations.input.data() + pos * model_dim,
|
||||
activations.attention_out.data() + pos * model_dim, model_dim);
|
||||
}
|
||||
|
||||
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
|
||||
activations.attention_out.data(), kModelDim, num_tokens,
|
||||
activations.attention_out.data(), model_dim, num_tokens,
|
||||
activations.bf_pre_ffw_rms_out.data(), pool);
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
const size_t kFFHiddenDim = config.ff_hidden_dim;
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec<kFFHiddenDim * 2, kModelDim>(
|
||||
weights.gating_einsum_w, 0,
|
||||
activations.bf_pre_ffw_rms_out.data() + pos * kModelDim,
|
||||
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
|
||||
activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
|
||||
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
|
||||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
|
|
@ -215,77 +216,76 @@ void ApplyForwardLayer(const LayerT& weights,
|
|||
}
|
||||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec<kModelDim, kFFHiddenDim>(
|
||||
weights.linear_w, 0,
|
||||
MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
|
||||
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
|
||||
output + pos * kModelDim, pool);
|
||||
output + pos * model_dim, pool);
|
||||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(activations.attention_out.data() + pos * kModelDim,
|
||||
output + pos * kModelDim, kModelDim);
|
||||
AddFrom(activations.attention_out.data() + pos * model_dim,
|
||||
output + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TConfig, typename WeightsT, typename LayerT>
|
||||
template <typename T>
|
||||
float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
||||
size_t context_size, const WeightsT& weights,
|
||||
ForwardPass<float, TConfig>& forward,
|
||||
size_t context_size,
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
ForwardPass<float>& forward,
|
||||
const RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const float kEmbScaling = EmbeddingScaling<TConfig>();
|
||||
static_assert(!TConfig::kAbsolutePE);
|
||||
static_assert(TConfig::kPostNorm == PostNormType::None);
|
||||
static_assert(TConfig::kKVHeads == 1);
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const float emb_scaling = EmbeddingScaling(model_dim);
|
||||
HWY_ASSERT(!config.absolute_pe);
|
||||
HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None);
|
||||
HWY_ASSERT(config.layer_configs[0].kv_heads == 1);
|
||||
|
||||
HWY_DASSERT(context_size > 0);
|
||||
HWY_DASSERT(context_size < prompt.size());
|
||||
const size_t num_tokens = prompt.size() - 1;
|
||||
|
||||
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
|
||||
forward.layers[0].input.data(), kModelDim, kVocabSize);
|
||||
InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
|
||||
forward.layers[0].input.data(), model_dim, vocab_size);
|
||||
|
||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
|
||||
auto type = config.layer_configs[layer].type;
|
||||
// TODO(szabadka) Implement Griffin layer.
|
||||
HWY_ASSERT(type == LayerAttentionType::kGemma);
|
||||
float* HWY_RESTRICT output = layer + 1 < kLayers ?
|
||||
forward.layers[layer + 1].input.data() :
|
||||
forward.final_layer_output.data();
|
||||
ApplyForwardLayer<TConfig, LayerT>(*weights.GetLayer(layer),
|
||||
forward.layers[layer], num_tokens,
|
||||
output, inv_timescale, pool);
|
||||
float* HWY_RESTRICT output = layer + 1 < layers
|
||||
? forward.layers[layer + 1].input.data()
|
||||
: forward.final_layer_output.data();
|
||||
ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
|
||||
num_tokens, output, inv_timescale, pool);
|
||||
}
|
||||
|
||||
ApplyRMSNorm(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
kModelDim, num_tokens, forward.final_norm_output.data(), pool);
|
||||
forward.final_layer_output.data(), model_dim, num_tokens,
|
||||
forward.final_norm_output.data(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec<kVocabSize, kModelDim>(
|
||||
weights.embedder_input_embedding, 0,
|
||||
forward.final_norm_output.data() + pos * kModelDim,
|
||||
forward.logits.data() + pos * kVocabSize, pool);
|
||||
MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
|
||||
forward.final_norm_output.data() + pos * model_dim,
|
||||
forward.logits.data() + pos * vocab_size, pool);
|
||||
}
|
||||
|
||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
LogitsSoftCap(TConfig::kFinalCap,
|
||||
forward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||
LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
|
||||
num_tokens * kVocabSize * sizeof(forward.logits.At(0)));
|
||||
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize);
|
||||
Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
|
||||
kVocabSize, pool);
|
||||
vocab_size, pool);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
@ -17,8 +17,9 @@
|
|||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
|
|
@ -36,38 +37,13 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <typename TConfig>
|
||||
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const ByteStorageT& weights_u8,
|
||||
ByteStorageT& forward_u8,
|
||||
float CrossEntropyLossForwardPassT(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
ForwardPass<float>& forward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
const auto& weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
auto& forward =
|
||||
*reinterpret_cast<ForwardPass<float, TConfig>*>(forward_u8.get());
|
||||
return CrossEntropyLossForwardPass<TConfig, CompressedWeights<TConfig>,
|
||||
CompressedLayer<TConfig>>(
|
||||
prompt.tokens, prompt.context_size, weights, forward, inv_timescale,
|
||||
pool);
|
||||
}
|
||||
|
||||
float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
ByteStorageT& forward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
// TODO(janwas): use CallFunctorForModel
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return CrossEntropyLossForwardPass<ConfigGemma2B<float>>(
|
||||
prompt, weights, forward, inv_timescale, pool);
|
||||
case Model::GEMMA_TINY:
|
||||
return CrossEntropyLossForwardPass<ConfigGemmaTiny<float>>(
|
||||
prompt, weights, forward, inv_timescale, pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
|
||||
weights, forward, inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -79,13 +55,13 @@ namespace gcpp {
|
|||
|
||||
HWY_EXPORT(CrossEntropyLossForwardPassT);
|
||||
|
||||
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
ByteStorageT& forward,
|
||||
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
ForwardPass<float>& forward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool) {
|
||||
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
|
||||
model, prompt, weights, forward, inv_timescale, pool);
|
||||
prompt, weights, forward, inv_timescale, pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
@ -16,16 +16,17 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt,
|
||||
const ByteStorageT& weights,
|
||||
ByteStorageT& forward,
|
||||
float CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const ModelWeightsPtrs<float>& weights,
|
||||
ForwardPass<float>& forward,
|
||||
RowVectorBatch<float>& inv_timescale,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
@ -128,107 +128,106 @@ void InputEmbedding(const T* w, const std::vector<int>& tokens, T scaling,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void MaskedAttention(const T* qkv, T* output, size_t num_tokens,
|
||||
size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
|
||||
void MaskedAttention(const T* qkv, T* output, size_t num_tokens, size_t heads,
|
||||
size_t qkv_dim, size_t seq_len) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
const size_t qoffset = pos * (kHeads + 2) * kQKVDim;
|
||||
const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen;
|
||||
const T* q = qkv + qoffset + head * kQKVDim;
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
const size_t qoffset = pos * (heads + 2) * qkv_dim;
|
||||
const size_t aoffset = pos * heads * seq_len + head * seq_len;
|
||||
const T* q = qkv + qoffset + head * qkv_dim;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
output[aoffset + pos2] = DotT(q, k, kQKVDim);
|
||||
const T* k = qkv + (pos2 * (heads + 2) + heads) * qkv_dim;
|
||||
output[aoffset + pos2] = DotT(q, k, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) {
|
||||
void MaskedSoftmax(T* x, size_t num_tokens, size_t heads, size_t seq_len) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
size_t offset = pos * heads * seq_len + head * seq_len;
|
||||
Softmax(x + offset, pos + 1);
|
||||
memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
|
||||
memset(x + offset + pos + 1, 0, (seq_len - pos - 1) * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
void MixByAttention(const T* qkv, const T* attention, T* output,
|
||||
size_t num_tokens, size_t kHeads, size_t kQKVDim,
|
||||
size_t kSeqLen) {
|
||||
size_t num_tokens, size_t heads, size_t qkv_dim,
|
||||
size_t seq_len) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen];
|
||||
T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim];
|
||||
memset(out, 0, kQKVDim * sizeof(out[0]));
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
const T* att = &attention[pos * heads * seq_len + head * seq_len];
|
||||
T* out = &output[head * qkv_dim + pos * heads * qkv_dim];
|
||||
memset(out, 0, qkv_dim * sizeof(out[0]));
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
|
||||
size_t v_offset = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
|
||||
const T* v = &qkv[v_offset];
|
||||
MulByConstAndAddT(att[pos2], v, out, kQKVDim);
|
||||
MulByConstAndAddT(att[pos2], v, out, qkv_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template <typename T, typename TConfig>
|
||||
void ApplyLayer(const CompressedLayer<TConfig>& weights,
|
||||
ForwardLayer<T, TConfig>& activations, size_t num_tokens,
|
||||
T* output) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kSeqLen = TConfig::kSeqLen;
|
||||
static constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
static constexpr size_t kHeads = TConfig::kHeads;
|
||||
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
|
||||
static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim));
|
||||
template <typename T>
|
||||
void ApplyLayer(const LayerWeightsPtrs<T>& weights,
|
||||
ForwardLayer<T>& activations, size_t num_tokens, T* output) {
|
||||
const LayerConfig& layer_config = weights.layer_config;
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t seq_len = activations.input.Rows();
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
|
||||
|
||||
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
|
||||
activations.pre_att_rms_out.data(), kModelDim, num_tokens);
|
||||
activations.pre_att_rms_out.data(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
|
||||
activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim,
|
||||
num_tokens);
|
||||
activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
|
||||
for (size_t h = 0; h <= kHeads; ++h) {
|
||||
Rope(qkv + h * kQKVDim, kQKVDim, pos);
|
||||
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= heads; ++h) {
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, pos);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
|
||||
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
|
||||
MulByConstT(query_scale, qkv, heads * qkv_dim);
|
||||
}
|
||||
|
||||
MaskedAttention(activations.qkv.data(), activations.att.data(),
|
||||
num_tokens, kHeads, kQKVDim, kSeqLen);
|
||||
MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
|
||||
heads, qkv_dim, seq_len);
|
||||
|
||||
MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen);
|
||||
MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
|
||||
|
||||
MixByAttention(activations.qkv.data(), activations.att.data(),
|
||||
activations.att_out.data(), num_tokens, kHeads, kQKVDim,
|
||||
kSeqLen);
|
||||
activations.att_out.data(), num_tokens, heads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
|
||||
activations.attention_out.data(), kHeads, kModelDim, kQKVDim,
|
||||
activations.attention_out.data(), heads, model_dim, qkv_dim,
|
||||
num_tokens);
|
||||
|
||||
AddFromT(activations.input.data(), activations.attention_out.data(),
|
||||
num_tokens * kModelDim);
|
||||
num_tokens * model_dim);
|
||||
|
||||
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens);
|
||||
activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
|
||||
activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim,
|
||||
activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
|
||||
kFFHiddenDim, num_tokens);
|
||||
ff_hidden_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(),
|
||||
output, kModelDim, kFFHiddenDim, num_tokens);
|
||||
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
|
||||
model_dim, ff_hidden_dim, num_tokens);
|
||||
|
||||
AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim);
|
||||
AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -247,48 +246,47 @@ T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) {
|
|||
return loss * scaling;
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
template <typename T>
|
||||
T CrossEntropyLossForwardPass(const Prompt& prompt,
|
||||
const CompressedWeights<TConfig>& weights,
|
||||
ForwardPass<T, TConfig>& forward) {
|
||||
static constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
static constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
const ModelWeightsPtrs<T>& weights,
|
||||
ForwardPass<T>& forward) {
|
||||
const ModelConfig& config = weights.weights_config;
|
||||
const size_t model_dim = config.model_dim;
|
||||
const size_t vocab_size = config.vocab_size;
|
||||
const size_t layers = config.layer_configs.size();
|
||||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(kModelDim);
|
||||
InputEmbedding(weights.embedder_input_embedding.data(), tokens,
|
||||
kEmbScaling, forward.layers[0].input.data(), kModelDim);
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
|
||||
forward.layers[0].input.data(), model_dim);
|
||||
|
||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||
T* output = layer + 1 < kLayers ?
|
||||
forward.layers[layer + 1].input.data() :
|
||||
forward.final_layer_output.data();
|
||||
for (size_t layer = 0; layer < layers; ++layer) {
|
||||
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
|
||||
: forward.final_layer_output.data();
|
||||
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
|
||||
output);
|
||||
}
|
||||
|
||||
RMSNormT(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
forward.final_norm_output.data(), kModelDim, num_tokens);
|
||||
RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
|
||||
forward.final_norm_output.data(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(),
|
||||
forward.logits.data(), kVocabSize, kModelDim, num_tokens);
|
||||
forward.final_norm_output.data(), forward.logits.data(), vocab_size,
|
||||
model_dim, num_tokens);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
if constexpr (TConfig::kFinalCap > 0.0f) {
|
||||
Softcap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize,
|
||||
kVocabSize);
|
||||
if (config.final_cap > 0.0f) {
|
||||
Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(forward.probs.data(), forward.logits.data(),
|
||||
num_tokens * kVocabSize * sizeof(forward.logits.At(0)));
|
||||
Softmax(forward.probs.data(), kVocabSize, num_tokens);
|
||||
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
|
||||
Softmax(forward.probs.data(), vocab_size, num_tokens);
|
||||
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize);
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
@ -16,6 +16,7 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -26,8 +27,10 @@
|
|||
#include "backprop/optimizer.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/threading.h"
|
||||
|
|
@ -45,20 +48,18 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
.training = ModelTraining::GEMMA_IT,
|
||||
.weight = Type::kF32,
|
||||
};
|
||||
ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
|
||||
info.model, info.weight, pool);
|
||||
ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
|
||||
info.model, info.weight, pool);
|
||||
ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
|
||||
info.model, info.weight, pool);
|
||||
ByteStorageT forward =
|
||||
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
|
||||
ByteStorageT backward =
|
||||
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
|
||||
KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16);
|
||||
ModelConfig config = ConfigFromModel(info.model);
|
||||
ModelWeightsStorage grad, grad_m, grad_v;
|
||||
grad.Allocate(info.model, info.weight, pool);
|
||||
grad_m.Allocate(info.model, info.weight, pool);
|
||||
grad_v.Allocate(info.model, info.weight, pool);
|
||||
grad_m.ZeroInit();
|
||||
grad_v.ZeroInit();
|
||||
ForwardPass<float> forward(config), backward(config);
|
||||
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
|
||||
|
||||
RowVectorBatch<float> inv_timescale =
|
||||
Activations::CreateInvTimescale<ConfigGemmaTiny<float>>();
|
||||
RowVectorBatch<float> inv_timescale = Activations::CreateInvTimescale(
|
||||
config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
|
||||
|
||||
Gemma gemma(GemmaTokenizer(), info, pools);
|
||||
|
||||
|
|
@ -92,14 +93,11 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
reply.begin() + context.size());
|
||||
};
|
||||
|
||||
RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
|
||||
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
|
||||
grad_m, pool);
|
||||
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
|
||||
grad_v, pool);
|
||||
gemma.MutableWeights().RandInit(gen);
|
||||
gemma.MutableWeights().AllocAndCopyWithTranspose(pool);
|
||||
|
||||
printf("Initial weights:\n");
|
||||
LogWeightStats(info.model, info.weight, gemma.Weights());
|
||||
gemma.MutableWeights().LogWeightStats();
|
||||
|
||||
constexpr size_t kBatchSize = 8;
|
||||
const float alpha = 0.001f;
|
||||
|
|
@ -113,29 +111,29 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
size_t num_ok;
|
||||
for (; steps < 1000000; ++steps) {
|
||||
std::mt19937 sgen(42);
|
||||
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
|
||||
grad, pool);
|
||||
grad.ZeroInit();
|
||||
float total_loss = 0.0f;
|
||||
num_ok = 0;
|
||||
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||
Prompt prompt = training_task.Sample(sgen);
|
||||
total_loss += CrossEntropyLossForwardPass(
|
||||
info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
|
||||
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
|
||||
grad, backward, inv_timescale, pool);
|
||||
CallForModelAndWeight<ReshapeCompressedWeights>(
|
||||
info.model, info.weight, gemma.MutableWeights(), pool);
|
||||
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
|
||||
inv_timescale, pool);
|
||||
CrossEntropyLossBackwardPass(
|
||||
prompt, *gemma.Weights().GetWeightsOfType<float>(), forward,
|
||||
*grad.GetWeightsOfType<float>(), backward, inv_timescale, pool);
|
||||
gemma.MutableWeights().CopyWithTranspose(pool);
|
||||
num_ok += verify(prompt) ? 1 : 0;
|
||||
}
|
||||
total_loss /= kBatchSize;
|
||||
|
||||
AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon,
|
||||
steps + 1, gemma.Weights(), grad_m, grad_v, pool);
|
||||
AdamUpdate(info.weight, grad, alpha, beta1, beta2, epsilon, steps + 1,
|
||||
gemma.Weights(), grad_m, grad_v, pool);
|
||||
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
|
||||
steps, total_loss, num_ok, kBatchSize);
|
||||
if (steps % 100 == 0) {
|
||||
printf("Batch gradient:\n");
|
||||
LogWeightStats(info.model, info.weight, grad);
|
||||
grad.LogWeightStats();
|
||||
}
|
||||
if (total_loss < 0.5f) {
|
||||
break;
|
||||
|
|
@ -143,7 +141,7 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
}
|
||||
printf("Num steps: %zu\n", steps);
|
||||
printf("Final weights:\n");
|
||||
LogWeightStats(info.model, info.weight, gemma.Weights());
|
||||
gemma.MutableWeights().LogWeightStats();
|
||||
EXPECT_LT(steps, 300);
|
||||
EXPECT_EQ(num_ok, kBatchSize);
|
||||
}
|
||||
@ -16,7 +16,6 @@
|
|||
#include "backprop/optimizer.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/common.h"
|
||||
|
|
@ -30,37 +29,6 @@ namespace gcpp {
|
|||
|
||||
namespace {
|
||||
|
||||
class WeightInitializer {
|
||||
public:
|
||||
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
|
||||
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
float* data = tensors[0]->data<float>();
|
||||
for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
|
||||
data[i] = dist_(gen_);
|
||||
}
|
||||
tensors[0]->set_scale(1.0f);
|
||||
}
|
||||
|
||||
private:
|
||||
std::normal_distribution<float> dist_;
|
||||
std::mt19937& gen_;
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct RandInitWeightsT {
|
||||
void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen) const {
|
||||
auto& weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
// TODO(szabadka) Use the same weight initialization method as in the python
|
||||
// version.
|
||||
WeightInitializer init(gen);
|
||||
CompressedWeights<TConfig>::ForEachTensor({&weights},
|
||||
ForEachType::kLoadNoToc, init);
|
||||
}
|
||||
};
|
||||
|
||||
class AdamUpdater {
|
||||
public:
|
||||
explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
|
||||
|
|
@ -97,42 +65,31 @@ class AdamUpdater {
|
|||
float epsilon_;
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct AdamUpdateT {
|
||||
void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
|
||||
void AdamUpdate(ModelWeightsPtrs<float>* grad, float alpha, float beta1,
|
||||
float beta2, float epsilon, size_t t,
|
||||
const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
|
||||
const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
|
||||
using TWeights = CompressedWeights<TConfig>;
|
||||
auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
|
||||
auto& weights = *reinterpret_cast<TWeights*>(weights_u8.get());
|
||||
auto& grad_m = *reinterpret_cast<TWeights*>(grad_m_u8.get());
|
||||
auto& grad_v = *reinterpret_cast<TWeights*>(grad_v_u8.get());
|
||||
ModelWeightsPtrs<float>* weights,
|
||||
ModelWeightsPtrs<float>* grad_m,
|
||||
ModelWeightsPtrs<float>* grad_v, hwy::ThreadPool& pool) {
|
||||
AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
|
||||
TWeights::ForEachTensor(
|
||||
{&grad, &weights, &grad_m, &grad_v}, ForEachType::kLoadNoToc,
|
||||
ModelWeightsPtrs<float>::ForEachTensor(
|
||||
{grad, weights, grad_m, grad_v}, ForEachType::kLoadNoToc,
|
||||
[&updater](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
updater(name, *tensors[0], *tensors[1], *tensors[2], *tensors[3]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
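
For reference, the per-element rule that AdamUpdater applies to the (grad, weights, grad_m, grad_v) tensors above is the standard Adam update. A generic standalone sketch over flat arrays (illustrative only, not the exact code in this file):

#include <cmath>
#include <cstddef>

// One Adam step with bias correction; alpha/beta1/beta2/epsilon/t mirror the
// parameters passed to AdamUpdate above.
void AdamStep(const float* g, float* w, float* m, float* v, size_t n,
              float alpha, float beta1, float beta2, float epsilon, size_t t) {
  const float bias1 = 1.0f - std::pow(beta1, static_cast<float>(t));
  const float bias2 = 1.0f - std::pow(beta2, static_cast<float>(t));
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
    w[i] -= alpha * (m[i] / bias1) / (std::sqrt(v[i] / bias2) + epsilon);
  }
}
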
|
||||
|
||||
void RandInitWeights(Model model_type, Type weight_type,
|
||||
const ByteStorageT& weights, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen) {
|
||||
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
|
||||
float beta1, float beta2, float epsilon, size_t t,
|
||||
const ModelWeightsStorage& weights,
|
||||
const ModelWeightsStorage& grad_m,
|
||||
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(weight_type == Type::kF32);
|
||||
CallForModel<float, RandInitWeightsT>(model_type, weights, pool, gen);
|
||||
}
|
||||
|
||||
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
|
||||
float alpha, float beta1, float beta2, float epsilon, size_t t,
|
||||
const ByteStorageT& weights, const ByteStorageT& grad_m,
|
||||
const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
|
||||
HWY_ASSERT(weight_type == Type::kF32);
|
||||
CallForModel<float, AdamUpdateT>(model_type, grad, alpha, beta1, beta2,
|
||||
epsilon, t, weights, grad_m, grad_v, pool);
|
||||
AdamUpdate(grad.GetWeightsOfType<float>(), alpha, beta1, beta2, epsilon, t,
|
||||
weights.GetWeightsOfType<float>(),
|
||||
grad_m.GetWeightsOfType<float>(), grad_v.GetWeightsOfType<float>(),
|
||||
pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
@ -16,22 +16,17 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "util/allocator.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void RandInitWeights(Model model_type, Type weight_type,
|
||||
const ByteStorageT& weights, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen);
|
||||
|
||||
void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad,
|
||||
float alpha, float beta1, float beta2, float epsilon, size_t t,
|
||||
const ByteStorageT& weights, const ByteStorageT& grad_m,
|
||||
const ByteStorageT& grad_v, hwy::ThreadPool& pool);
|
||||
void AdamUpdate(Type weight_type, const ModelWeightsStorage& grad, float alpha,
|
||||
float beta1, float beta2, float epsilon, size_t t,
|
||||
const ModelWeightsStorage& weights,
|
||||
const ModelWeightsStorage& grad_m,
|
||||
const ModelWeightsStorage& grad_v, hwy::ThreadPool& pool);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
@ -21,11 +21,12 @@
|
|||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -39,8 +40,8 @@ void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
|
|||
}
|
||||
|
||||
// TODO: make a member of Layer<T>.
|
||||
template <typename T, typename TConfig>
|
||||
void RandInit(CompressedLayer<TConfig>& w, T stddev, std::mt19937& gen) {
|
||||
template <typename T>
|
||||
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
|
||||
RandInit(w.pre_attention_norm_scale, stddev, gen);
|
||||
RandInit(w.attn_vec_einsum_w, stddev, gen);
|
||||
RandInit(w.qkv_einsum_w, stddev, gen);
|
||||
|
|
@ -49,9 +50,9 @@ void RandInit(CompressedLayer<TConfig>& w, T stddev, std::mt19937& gen) {
|
|||
RandInit(w.linear_w, stddev, gen);
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
void RandInit(CompressedWeights<TConfig>& w, T stddev, std::mt19937& gen) {
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
template <typename T>
|
||||
void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
|
||||
const size_t kLayers = w.c_layers.size();
|
||||
RandInit(w.embedder_input_embedding, stddev, gen);
|
||||
RandInit(w.final_norm_scale, stddev, gen);
|
||||
for (size_t i = 0; i < kLayers; ++i) {
|
||||
|
|
@ -66,9 +67,8 @@ void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename TConfig, typename UConfig>
|
||||
void Complexify(const CompressedLayer<TConfig>& w,
|
||||
CompressedLayer<UConfig>& c_w) {
|
||||
template <typename T, typename U>
|
||||
void Complexify(const LayerWeightsPtrs<T>& w, LayerWeightsPtrs<U>& c_w) {
|
||||
Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale);
|
||||
Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w);
|
||||
Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w);
|
||||
|
|
@ -77,10 +77,9 @@ void Complexify(const CompressedLayer<TConfig>& w,
|
|||
Complexify(w.linear_w, c_w.linear_w);
|
||||
}
|
||||
|
||||
template <typename TConfig, typename UConfig>
|
||||
void Complexify(const CompressedWeights<TConfig>& w,
|
||||
CompressedWeights<UConfig>& c_w) {
|
||||
static constexpr size_t kLayers = TConfig::kLayers;
|
||||
template <typename T, typename U>
|
||||
void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
|
||||
const size_t kLayers = w.c_layers.size();
|
||||
Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding);
|
||||
Complexify(w.final_norm_scale, c_w.final_norm_scale);
|
||||
for (size_t i = 0; i < kLayers; ++i) {
|
||||
|
|
@ -88,26 +87,27 @@ void Complexify(const CompressedWeights<TConfig>& w,
|
|||
}
|
||||
}
|
||||
|
||||
// Owns weights and provides access to TConfig.
|
||||
template <typename TConfig>
|
||||
// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
|
||||
// complex types allowed and it would cause code bloat to add them there.
|
||||
template <typename T>
|
||||
class WeightsWrapper {
|
||||
public:
|
||||
WeightsWrapper()
|
||||
: pool_(0),
|
||||
data_(AllocateCompressedWeights<TConfig>()(pool_)),
|
||||
weights_(reinterpret_cast<CompressedWeights<TConfig>*>(data_.get())) {}
|
||||
explicit WeightsWrapper(const ModelConfig& config)
|
||||
: pool_(0), weights_(config, pool_) {
|
||||
weights_.Allocate(data_, pool_);
|
||||
}
|
||||
|
||||
const CompressedWeights<TConfig>& get() const { return *weights_; }
|
||||
CompressedWeights<TConfig>& get() { return *weights_; }
|
||||
void ZeroInit() { weights_->ZeroInit(); }
|
||||
void CopyFrom(const WeightsWrapper<TConfig>& other) {
|
||||
get().CopyFrom(other.get());
|
||||
const ModelWeightsPtrs<T>& get() const { return weights_; }
|
||||
ModelWeightsPtrs<T>& get() { return weights_; }
|
||||
void ZeroInit() { weights_.ZeroInit(); }
|
||||
void CopyFrom(const WeightsWrapper<T>& other) {
|
||||
weights_.CopyFrom(other.weights_);
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::ThreadPool pool_;
|
||||
ByteStorageT data_;
|
||||
CompressedWeights<TConfig>* weights_;
|
||||
std::vector<MatStorage> data_;
|
||||
ModelWeightsPtrs<T> weights_;
|
||||
};
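
A brief usage sketch of the rewritten wrapper, mirroring how the tests below use it (illustrative only; TestConfig(), RandInit() and Complexify() are the helpers defined elsewhere in this change):

std::mt19937 gen(42);
ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config);                   // owns f32 tensors
WeightsWrapper<std::complex<double>> c_weights(config);  // complex shadow copy
RandInit(weights.get(), 1.0f, gen);                      // helpers take ModelWeightsPtrs
Complexify(weights.get(), c_weights.get());              // used for gradient checks
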
|
||||
|
||||
template <typename T, typename U>
|
||||
|
|
@ -173,9 +173,9 @@ void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<double>>& x,
|
|||
TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig, typename UConfig, typename FUNC>
|
||||
void TestGradient(const CompressedLayer<TConfig>& grad,
|
||||
CompressedLayer<UConfig>& c_weights, FUNC func, T max_err) {
|
||||
template <typename T, typename U, typename FUNC>
|
||||
void TestGradient(const LayerWeightsPtrs<T>& grad,
|
||||
LayerWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
|
||||
TestGradient(grad.pre_attention_norm_scale,
|
||||
c_weights.pre_attention_norm_scale,
|
||||
func, max_err, max_err, __LINE__);
|
||||
|
|
@ -191,15 +191,15 @@ void TestGradient(const CompressedLayer<TConfig>& grad,
|
|||
func, max_err, max_err, __LINE__);
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig, typename UConfig, typename FUNC>
|
||||
void TestGradient(const CompressedWeights<TConfig>& grad,
|
||||
CompressedWeights<UConfig>& c_weights, FUNC func, T max_err) {
|
||||
template <typename T, typename U, typename FUNC>
|
||||
void TestGradient(const ModelWeightsPtrs<T>& grad,
|
||||
ModelWeightsPtrs<U>& c_weights, FUNC func, T max_err) {
|
||||
TestGradient(grad.embedder_input_embedding,
|
||||
c_weights.embedder_input_embedding,
|
||||
func, 2 * max_err, max_err, __LINE__);
|
||||
TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
|
||||
func, max_err, max_err, __LINE__);
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
for (size_t i = 0; i < grad.c_layers.size(); ++i) {
|
||||
TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
|
||||
}
|
||||
}
|
||||
@ -21,7 +21,6 @@
|
|||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -276,6 +275,7 @@ BlobError BlobReader::ReadAll(hwy::ThreadPool& pool) {
|
|||
[pfile, &requests, &err](uint64_t i, size_t /*thread*/) {
|
||||
if (!pfile->Read(requests[i].offset, requests[i].size,
|
||||
requests[i].data)) {
|
||||
fprintf(stderr, "Failed to read blob %zu\n", i);
|
||||
err.test_and_set();
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -102,8 +102,8 @@ class CompressedArray {
|
|||
class MatPtr {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtr(const std::string& name, const std::string& type, size_t element_size,
|
||||
size_t rows, size_t cols)
|
||||
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
||||
size_t cols)
|
||||
: name_(name),
|
||||
type_(type),
|
||||
element_size_(element_size),
|
||||
|
|
@ -129,7 +129,7 @@ class MatPtr {
|
|||
MatPtr(const hwy::uint128_t& key0, const hwy::uint128_t& key1,
|
||||
const hwy::uint128_t& key2, const hwy::uint128_t& key3)
|
||||
: name_(StringFromKey(key0)),
|
||||
type_(StringFromKey(key1)),
|
||||
type_(static_cast<Type>(key1.lo)),
|
||||
element_size_(key2.hi),
|
||||
num_elements_(key2.lo),
|
||||
rows_(key3.lo),
|
||||
|
|
@ -138,7 +138,7 @@ class MatPtr {
|
|||
// Adds the contents entry to the table of contents.
|
||||
void AddToToc(std::vector<hwy::uint128_t>& toc) const {
|
||||
toc.push_back(MakeKey(name_.c_str()));
|
||||
toc.push_back(MakeKey(type_.c_str()));
|
||||
toc.push_back({static_cast<uint64_t>(type_), 0});
|
||||
toc.push_back({num_elements_, element_size_});
|
||||
toc.push_back({rows_, cols_});
|
||||
}
|
||||
|
|
@ -167,7 +167,7 @@ class MatPtr {
|
|||
void SetName(const std::string& name) { name_ = name; }
|
||||
|
||||
// Returns the type of the blob.
|
||||
const std::string& Type() const { return type_; }
|
||||
Type GetType() const { return type_; }
|
||||
|
||||
// Returns the size of each element in bytes.
|
||||
size_t ElementSize() const { return element_size_; }
|
||||
|
|
@ -219,8 +219,8 @@ class MatPtr {
|
|||
protected:
|
||||
// Arbitrary name for the array, preferably <= 16 characters.
|
||||
std::string name_;
|
||||
// Should be the result of TypeName<T> for CallUpcasted() to work.
|
||||
std::string type_;
|
||||
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
||||
Type type_;
|
||||
// sizeof(T)
|
||||
size_t element_size_ = 0;
|
||||
// Number of elements in the array.
|
||||
|
|
@ -247,7 +247,7 @@ class MatPtrT : public MatPtr {
|
|||
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtrT(const std::string& name, size_t rows, size_t cols)
|
||||
: MatPtr(name, TypeName<MatT>(), sizeof(MatT), rows, cols) {}
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtrT(const MatPtr& other) : MatPtr(other) {}
|
||||
|
|
@ -330,17 +330,20 @@ class MatPtrT : public MatPtr {
|
|||
|
||||
template <class FuncT, typename... TArgs>
|
||||
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
|
||||
if (type_ == TypeName<float>()) {
|
||||
if (type_ == TypeEnum<float>()) {
|
||||
return func(dynamic_cast<MatPtrT<float>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeName<BF16>()) {
|
||||
} else if (type_ == TypeEnum<BF16>()) {
|
||||
return func(dynamic_cast<MatPtrT<BF16>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeName<SfpStream>()) {
|
||||
} else if (type_ == TypeEnum<SfpStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<NuqStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Type %s unknown.", type_.c_str());
|
||||
HWY_ABORT("Type %d unknown.", type_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -563,9 +566,10 @@ class CacheLoader {
|
|||
}
|
||||
|
||||
// Returns whether all tensors are successfully loaded from cache.
|
||||
bool ReadAll(hwy::ThreadPool& pool, std::vector<MatStorage>& model_memory) {
|
||||
BlobError ReadAll(hwy::ThreadPool& pool,
|
||||
std::vector<MatStorage>& model_memory) {
|
||||
// reader_ invalid or any Enqueue failed
|
||||
if (err_ != 0) return false;
|
||||
if (err_ != 0) return err_;
|
||||
// Setup the model_memory.
|
||||
for (int b = 0; b < model_toc_.size(); ++b) {
|
||||
const std::string& file_key = file_keys_[b];
|
||||
|
|
@ -574,12 +578,12 @@ class CacheLoader {
|
|||
const MatPtr* toc_blob = file_toc_.Get(file_key);
|
||||
if (toc_blob == nullptr) {
|
||||
fprintf(stderr, "Blob %s not found in TOC\n", file_key.c_str());
|
||||
return false;
|
||||
return __LINE__;
|
||||
}
|
||||
if (toc_blob->Rows() != blob->Rows() ||
|
||||
toc_blob->Cols() != blob->Cols()) {
|
||||
fprintf(stderr, "Blob %s has size mismatch TOC\n", file_key.c_str());
|
||||
return false;
|
||||
return __LINE__;
|
||||
}
|
||||
MatStorage toc_blob_array(*toc_blob);
|
||||
model_memory.push_back(std::move(toc_blob_array));
|
||||
|
|
@ -603,17 +607,10 @@ class CacheLoader {
|
|||
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
||||
blob.Name().c_str(), err_, blob.Rows(), blob.Cols(),
|
||||
blob.ElementSize());
|
||||
return false;
|
||||
return err_;
|
||||
}
|
||||
}
|
||||
|
||||
err_ = reader_.ReadAll(pool);
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr, "Failed to read all tensors (error %d)\n", err_);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return reader_.ReadAll(pool);
|
||||
}
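
With this change a caller checks the returned code rather than a bool. A hedged sketch of the new calling convention (the loader and pool objects are assumed to exist at the call site):

std::vector<MatStorage> model_memory;
if (const BlobError err = loader.ReadAll(pool, model_memory); err != 0) {
  fprintf(stderr, "Loading weights failed (error %d)\n", err);
}
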
|
||||
|
||||
private:
|
||||
@ -24,6 +24,7 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
|
@ -150,29 +151,22 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <class Configs>
|
||||
template <typename T>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path, Model model_type,
Type weight_type, hwy::ThreadPool& pool) {
hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
compressed_weights_path.path.c_str());

using CConfig = typename Configs::c;
using UCConfig = typename Configs::uc;
// Allocate compressed weights.
using CWeights = CompressedWeights<CConfig>;
ByteStorageT c_weights_u8 = AllocateCompressedWeights<CConfig>()(pool);
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());

// Allocate uncompressed weights.
using UCWeights = CompressedWeights<UCConfig>;
ByteStorageT uc_weights_u8 = AllocateCompressedWeights<UCConfig>()(pool);
UCWeights* uc_weights = reinterpret_cast<UCWeights*>(uc_weights_u8.get());

ModelConfig config = ConfigFromModel(model_type);
std::vector<MatStorage> model_storage;
ModelWeightsPtrs<T> c_weights(config, pool);
c_weights.Allocate(model_storage, pool);
ModelWeightsPtrs<float> uc_weights(config, pool);
uc_weights.Allocate(model_storage, pool);
// Get uncompressed weights, compress, and store.
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
if (fptr == nullptr) {

@@ -181,22 +175,22 @@ void CompressWeights(const Path& weights_path,
}
bool ok = true;
uint64_t total_size = 0;
CompressedWeights<UCConfig>::ForEachTensor(
{uc_weights}, ForEachType::kLoadNoToc,
ModelWeightsPtrs<float>::ForEachTensor(
{&uc_weights}, ForEachType::kLoadNoToc,
[&](const char* name, hwy::Span<MatPtr*> tensors) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
tensors[0]->SizeBytes(), name);
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
total_size += tensors[0]->SizeBytes();
});
const bool scale_for_compression = UCConfig::kNumTensorScales > 0;
const bool scale_for_compression = config.num_tensor_scales > 0;
std::vector<float> scales;
if (scale_for_compression) {
uc_weights->GetOrApplyScales(scales);
uc_weights.GetOrApplyScales(scales);
}
Compressor compressor(pool);
CompressedWeights<CConfig>::ForEachTensor(
{reinterpret_cast<CompressedWeights<CConfig>*>(uc_weights), c_weights},
ModelWeightsPtrs<T>::ForEachTensor(
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
ForEachType::kLoadNoToc,
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[1]->CallUpcasted(

@@ -221,9 +215,26 @@ void Run(Args& args) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Type weight_type = args.WeightType();
GEMMA_EXPORT_AND_DISPATCH(
model_type, weight_type, CompressWeights,
(args.weights, args.compressed_weights, model_type, weight_type, pool));
switch (weight_type) {
case Type::kF32:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kBF16:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kSFP:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
(args.weights, args.compressed_weights, model_type, pool);
break;
case Type::kNUQ:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
(args.weights, args.compressed_weights, model_type, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
}
}

} // namespace gcpp
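The switch above dispatches only on the weight type, because the model itself is now described by a runtime ModelConfig rather than a compile-time TConfig. A minimal sketch of the same pattern, using a hypothetical DoWork functor (not part of gemma.cpp) and stand-in element types, to show that only one template argument remains to expand:

// Sketch only: DoWork and the stand-in types are hypothetical; the real code
// dispatches CompressWeights<float/BF16/SfpStream/NuqStream> as shown above.
#include <cstdio>

enum class WeightTypeSketch { kF32, kBF16, kSFP, kNUQ };

template <typename T>
void DoWork() { std::printf("element size = %zu bytes\n", sizeof(T)); }

void DispatchSketch(WeightTypeSketch t) {
  switch (t) {
    case WeightTypeSketch::kF32:  DoWork<float>(); break;
    case WeightTypeSketch::kBF16: DoWork<unsigned short>(); break;  // stand-in for BF16
    case WeightTypeSketch::kSFP:  DoWork<unsigned char>(); break;   // stand-in for SfpStream
    case WeightTypeSketch::kNUQ:  DoWork<unsigned char>(); break;   // stand-in for NuqStream
  }
}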
@@ -32,11 +32,6 @@ namespace gcpp {

using BF16 = hwy::bfloat16_t;

template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}

// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.

@@ -179,29 +174,67 @@ struct NuqStream {
};
#pragma pack(pop)

template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}

template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
}

template <typename Packed>
constexpr bool IsSfpStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, SfpStream>();
}

template <typename Packed>
constexpr bool IsNuqStream() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}

// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };

// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// ModelWeightsPtrs. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64", "u128"};

// Returns a Type enum for the type of the template parameter.
template <typename PackedT>
const char* TypeName() {
Type TypeEnum() {
using Packed = hwy::RemoveCvRef<PackedT>;
if constexpr (hwy::IsSame<Packed, float>()) {
return "f32";
return Type::kF32;
} else if constexpr (hwy::IsSame<Packed, BF16>()) {
return "b16";
return Type::kBF16;
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
return "sfp";
return Type::kSFP;
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
return "nuq";
return Type::kNUQ;
} else if constexpr (hwy::IsSame<Packed, double>()) {
return "f64";
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return "c64";
return Type::kC64;
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
return "u128";
return Type::kU128;
} else {
HWY_DASSERT(false);
return "unknown";
return Type::kUnknown;
}
}

// Returns a string name for the type of the template parameter.
template <typename PackedT>
const char* TypeName() {
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
}

template <typename Packed>
constexpr bool IsCompressed() {
return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
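With TypeEnum returning the enum and TypeName indexing kTypeStrings, the string and enum views of a weight type stay in sync. A hedged usage sketch, assuming only the definitions shown in the hunk above:

// Sketch: relies on Type, kTypeStrings, TypeEnum and TypeName from the header above.
#include <cstdio>

void PrintWeightType() {
  // kSFP is the fourth enumerator (kUnknown, kF32, kBF16, kSFP, ...).
  static_assert(static_cast<int>(Type::kSFP) == 3, "enum order matches kTypeStrings");
  std::printf("%s\n", TypeName<float>());  // prints "f32"
}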
@@ -128,8 +128,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(
env.GetModel()->Info().model, env.MutableConfig().prefill_tbatch_size);
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;

@@ -69,8 +69,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
model_ = AllocateGemma(mutable_loader, pools_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] =
KVCache::Create(model_->Info().model, inference.prefill_tbatch_size);
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
inference.prefill_tbatch_size);
}
InitGenerator(inference, gen_);
runtime_config_ = {

@@ -163,7 +163,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
}
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(model_->Info().model,
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
runtime_config_.prefill_tbatch_size);
}
}
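In all three call sites above, KVCache::Create now takes the ModelConfig directly instead of a Model enum, so callers need no per-model template machinery. A hedged sketch of the resulting call-site pattern, assuming only the GetModelConfig accessor and Create overload introduced in this change:

// Sketch: helper mirroring the pattern above; prefill_tbatch_size is a runtime argument.
KVCache MakeCache(Gemma& model, size_t prefill_tbatch_size) {
  return KVCache::Create(model.GetModelConfig(), prefill_tbatch_size);
}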
@@ -103,8 +103,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const StreamFunc stream_token = [](int /*token*/, float) { return true; };

// TWeight is unused, but we have to pass it to Config*.
const int vocab_size =
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
const int vocab_size = gemma.GetModelConfig().vocab_size;
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
@@ -24,7 +24,6 @@
#include <vector>

// Placeholder for internal header, do not modify.
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs

@@ -58,7 +57,8 @@ int main(int argc, char** argv) {
gcpp::PerClusterPools pools(app.max_clusters, app.max_threads, app.pin);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
gcpp::KVCache kv_cache =
gcpp::KVCache::Create(loader.Info().model, inference.prefill_tbatch_size);
gcpp::KVCache::Create(model.GetModelConfig(),
inference.prefill_tbatch_size);
size_t generated = 0;

// Initialize random number generator
@@ -21,6 +21,7 @@
#include <cmath>

#include "compression/shared.h" // BF16
#include "gemma/configs.h"
#include "ops/matmul.h" // MatMulEnv
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h"

@@ -30,6 +31,12 @@
namespace gcpp {

struct Activations {
explicit Activations(const ModelConfig& config)
: weights_config(config),
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()) {}

RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA.
RowVectorBatch<float> logits;

@@ -58,23 +65,24 @@ struct Activations {

MatMulEnv env;

PostQKType post_qk = PostQKType::Rope;
// And the config.
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0;

// Multi-Head Attention?
template <class TConfig>
static constexpr bool IsMHA() {
return TConfig::kHeads == TConfig::kKVHeads;
}
bool IsMHA() const { return layer_config.heads == layer_config.kv_heads; }

// Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
// but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
template <class TConfig>
static constexpr size_t QStride() {
return TConfig::kQKVDim * (IsMHA<TConfig>() ? 3 : 1);
}
size_t QStride() const { return layer_config.qkv_dim * (IsMHA() ? 3 : 1); }

template <class TConfig>
static RowVectorBatch<float> CreateInvTimescale() {
constexpr size_t kQKVDim = TConfig::kQKVDim;
const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim;
static RowVectorBatch<float> CreateInvTimescale(size_t qkv_dim,
PostQKType post_qk) {
const size_t rope_dim =
post_qk == PostQKType::HalfRope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(1, rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const float freq_exponents =

@@ -86,40 +94,38 @@ struct Activations {
return inv_timescale;
}

template <class TConfig>
void Allocate(size_t batch_size, PerClusterPools& pools) {
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kQKVDim = TConfig::kQKVDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
constexpr size_t kVocabSize = TConfig::kVocabSize;
constexpr size_t kSeqLen = TConfig::kSeqLen;
constexpr size_t kGriffinLayers = TConfig::kGriffinLayers;
post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t vocab_size = weights_config.vocab_size;

x = RowVectorBatch<float>(batch_size, kModelDim);
q = RowVectorBatch<float>(batch_size, kHeads * QStride<TConfig>());
if constexpr (kVocabSize > 0) {
logits = RowVectorBatch<float>(batch_size, kVocabSize);
x = RowVectorBatch<float>(batch_size, model_dim);
q = RowVectorBatch<float>(batch_size, layer_config.heads * QStride());
if (vocab_size > 0) {
logits = RowVectorBatch<float>(batch_size, vocab_size);
}

pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
att_sums = RowVectorBatch<float>(batch_size, kModelDim);
pre_att_rms_out = RowVectorBatch<float>(batch_size, model_dim);
att = RowVectorBatch<float>(batch_size,
layer_config.heads * weights_config.seq_len);
att_out = RowVectorBatch<float>(batch_size,
layer_config.heads * layer_config.qkv_dim);
att_sums = RowVectorBatch<float>(batch_size, model_dim);

bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, kModelDim);
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, model_dim);
C1 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
C2 = RowVectorBatch<float>(batch_size, ff_hidden_dim);
ffw_out = RowVectorBatch<float>(batch_size, model_dim);

if constexpr (kGriffinLayers > 0) {
griffin_x = RowVectorBatch<float>(batch_size, kModelDim);
griffin_y = RowVectorBatch<float>(batch_size, kModelDim);
griffin_gate_x = RowVectorBatch<float>(batch_size, kModelDim);
griffin_multiplier = RowVectorBatch<float>(batch_size, kModelDim);
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
griffin_x = RowVectorBatch<float>(batch_size, model_dim);
griffin_y = RowVectorBatch<float>(batch_size, model_dim);
griffin_gate_x = RowVectorBatch<float>(batch_size, model_dim);
griffin_multiplier = RowVectorBatch<float>(batch_size, model_dim);
}

inv_timescale = CreateInvTimescale<TConfig>();
inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk);

env = MatMulEnv(pools);
}
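CreateInvTimescale is now a plain function of qkv_dim and post_qk. The loop body is truncated in this hunk; the following standalone sketch shows the conventional RoPE inverse-timescale computation (assuming the usual 10000.0 base, which may differ from gemma.cpp's exact constants):

// Sketch only: standard RoPE timescales, not the verbatim library code.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> InvTimescaleSketch(size_t qkv_dim, bool half_rope) {
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  std::vector<float> inv(rope_dim / 2);
  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    // Frequencies decay geometrically with the dimension index.
    const float freq_exponent = 2.0f * dim / static_cast<float>(rope_dim);
    inv[dim] = 1.0f / std::pow(10000.0f, freq_exponent);
  }
  return inv;
}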
@@ -15,6 +15,7 @@

#include "gemma/common.h"

#include <math.h> // sqrtf
#include <stddef.h>
#include <string.h>

@@ -23,6 +24,7 @@
#include <string>
#include <vector>

#include "compression/shared.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

@@ -101,8 +103,6 @@ const char* ModelString(Model model, ModelTraining training) {
static_cast<int>(training));
}

constexpr const char* kTypeStrings[] = {"f32", "bf16", "sfp"};

const char* StringFromType(Type type) {
return kTypeStrings[static_cast<size_t>(type)];
}

@@ -141,4 +141,19 @@ void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
}

float EmbeddingScaling(size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
sqrtf(static_cast<float>(model_dim))));
}

float ChooseQueryScale(const ModelConfig& config) {
if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f / sqrtf(static_cast<float>(config.model_dim /
config.layer_configs[0].heads));
// QueryScaleType::SqrtKeySize
return 1.0f / sqrtf(static_cast<float>(config.layer_configs[0].qkv_dim));
}

} // namespace gcpp
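For intuition: EmbeddingScaling(2048) is sqrt(2048) ≈ 45.25 rounded through bf16, and ChooseQueryScale with qkv_dim = 256 and the SqrtKeySize policy yields 1/16. A small hedged check of the latter, independent of the library:

// Sketch: verifies the SqrtKeySize branch by hand; not part of gemma.cpp.
#include <cassert>
#include <cmath>

void CheckQueryScale() {
  const float scale = 1.0f / std::sqrt(256.0f);  // qkv_dim = 256
  assert(std::fabs(scale - 0.0625f) < 1e-6f);    // 1/16
}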
212
gemma/common.h

@@ -16,37 +16,15 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

#include <math.h> // sqrtf
#include <stddef.h>

#include <string>

#include "compression/compress.h"
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo

namespace gcpp {

// Model variants: see configs.h for details. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Model {
GEMMA_2B,
GEMMA_7B,
GEMMA2_9B,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
};

// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };

// Tensor types for loading weights. When adding a new one, also
// update GEMMA_FOREACH* and Call* below, and add instantiations/*.cc.
enum class Type { kF32, kBF16, kSFP };

// TODO(janwas): merge with functions below.
struct ModelInfo {
Model model;

@@ -66,198 +44,12 @@ const char* StringFromType(Type type);

void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);

// Returns the return value of FuncT<Config*<TWeight>>().operator()(args), where
// Config* is selected via `model`. Typically called by CallForModelAndWeight,
// but can also be called directly when FuncT does not actually use TWeight.
//
// Note that a T prefix indicates a concrete type template argument, whereas a
// T suffix indicates the argument is itself a template.
//
// `FuncT` must be a functor because function templates cannot be passed as a
// template template argument, and we prefer to avoid the overhead of
// std::function.
template <typename TWeight, template <typename TConfig> class FuncT,
typename... TArgs>
decltype(auto) CallForModel(Model model, TArgs&&... args) {
switch (model) {
case Model::GEMMA_TINY:
return FuncT<ConfigGemmaTiny<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_2B:
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_7B:
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_9B:
return FuncT<ConfigGemma2_9B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_27B:
return FuncT<ConfigGemma2_27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B:
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::PALIGEMMA_224:
return FuncT<ConfigPaliGemma_224<TWeight>>()(
std::forward<TArgs>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}

// Returns the return value of FuncT<TConfig>().operator()(args),
// where `TConfig` is selected based on `model` and `weight`.

// This makes it easy to extend `Model` or `Type` without updating callers.
//
// Usage example: LoadWeights is type-erased so that it can be called from other
// .cc files. It uses this function to call the appropriate instantiation of a
// template functor LoadCompressedWeightsT<TConfig>.
template <template <typename TConfig> class FuncT, typename... TArgs>
decltype(auto) CallForModelAndWeight(Model model, Type weight,
TArgs&&... args) {
switch (weight) {
case Type::kF32:
return CallForModel<float, FuncT, TArgs...>( //
model, std::forward<TArgs>(args)...);
case Type::kBF16:
return CallForModel<BF16, FuncT, TArgs...>(model,
std::forward<TArgs>(args)...);
case Type::kSFP:
return CallForModel<SfpStream, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
default:
HWY_ABORT("Weight type %d unknown.", static_cast<int>(weight));
}
}

#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \
X(CONFIGT, float) \
X(CONFIGT, BF16) \
X(CONFIGT, SfpStream)

#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigPaliGemma_224) \
static_assert(true, "Allow trailing ;")

// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
// calls FUNC<ConfigT<TWEIGHT>> where ConfigT is chosen via MODEL enum.
#define GEMMA_DISPATCH_MODEL(MODEL, TWEIGHT, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
using CP = ConfigPair<ConfigGemmaTiny<TWEIGHT>, ConfigGemmaTiny<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
using CP = ConfigPair<ConfigGemma2B<TWEIGHT>, ConfigGemma2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
using CP = ConfigPair<ConfigGemma7B<TWEIGHT>, ConfigGemma7B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
using CP = ConfigPair<ConfigGriffin2B<TWEIGHT>, ConfigGriffin2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_2B: { \
using CP = ConfigPair<ConfigGemma2_2B<TWEIGHT>, ConfigGemma2_2B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_9B: { \
using CP = ConfigPair<ConfigGemma2_9B<TWEIGHT>, ConfigGemma2_9B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::GEMMA2_27B: { \
using CP = \
ConfigPair<ConfigGemma2_27B<TWEIGHT>, ConfigGemma2_27B<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
case Model::PALIGEMMA_224: { \
using CP = ConfigPair<ConfigPaliGemma_224<TWEIGHT>, \
ConfigPaliGemma_224<float>>; \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<CP>) ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}

// Like CallForModelAndWeight, but for SIMD function templates. This is a macro
// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
// normal function argument. MODEL and WEIGHT are enums.
// For gemma.cc, we use overloaded extern functions for faster builds. However,
// this is still used in compress_weights because its compile time is OK.
#define GEMMA_EXPORT_AND_DISPATCH(MODEL, WEIGHT, FUNC, ARGS) \
switch (WEIGHT) { \
case Type::kF32: \
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
break; \
case Type::kBF16: \
GEMMA_DISPATCH_MODEL(MODEL, BF16, FUNC, ARGS); \
break; \
case Type::kSFP: \
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
break; \
default: \
HWY_ABORT("Weight type %d unknown.", static_cast<int>(WEIGHT)); \
}

// ----------------------------------------------------------------------------
//

// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif
float EmbeddingScaling(size_t model_dim);

// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif

template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(TConfig::kModelDim))));
}

static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(model_dim))));
}

template <class TConfig>
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
Sqrt(static_cast<float>(TConfig::kModelDim / TConfig::kHeads));
// QueryScaleType::SqrtKeySize
return 1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
}
float ChooseQueryScale(const ModelConfig& config);

} // namespace gcpp
@@ -0,0 +1,246 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/configs.h"

#include "hwy/base.h"

namespace gcpp {

static ModelConfig ConfigNoSSM() {
ModelConfig config = {.scale_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
"gr_lin_y_w", "gr_lin_out_w",
"gr_gate_w", "gating_ein", "linear_w"}};
return config;
}

static ModelConfig ConfigBaseGemmaV1() { return ConfigNoSSM(); }

static ModelConfig ConfigBaseGemmaV2() {
ModelConfig config = ConfigNoSSM();
config.att_cap = 50.0f;
config.final_cap = 30.0f;
return config;
}

static ModelConfig ConfigGemma2_27B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_27B";
config.model = Model::GEMMA2_27B;
config.model_dim = 4608;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
.heads = 32,
.kv_heads = 16,
.qkv_dim = 128,
.post_norm = PostNormType::Scale};
config.layer_configs = {46, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<46, 2>({4096, 8192});
return config;
}

static ModelConfig ConfigGemma2_9B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_9B";
config.model = Model::GEMMA2_9B;
config.model_dim = 3584;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
.heads = 16,
.kv_heads = 8,
.qkv_dim = 256,
.post_norm = PostNormType::Scale};
config.layer_configs = {42, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<42, 2>({4096, 8192});
return config;
}

static ModelConfig ConfigGemma2_2B() {
ModelConfig config = ConfigBaseGemmaV2();
config.model_name = "Gemma2_2B";
config.model = Model::GEMMA2_2B;
config.model_dim = 2304;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
.heads = 8,
.kv_heads = 4,
.qkv_dim = 256,
.post_norm = PostNormType::Scale};
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes =
RepeatedAttentionWindowSizes<26, 2>({4096, 8192});
return config;
}

static ModelConfig ConfigGemma7B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma7B";
config.model = Model::GEMMA_7B;
config.model_dim = 3072;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
.heads = 16,
.kv_heads = 16,
.qkv_dim = 256,
};
config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
return config;
}

static ModelConfig ConfigGemma2B() {
ModelConfig config = ConfigBaseGemmaV1();
config.model_name = "Gemma2B";
config.model = Model::GEMMA_2B;
config.model_dim = 2048;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
.heads = 8,
.kv_heads = 1,
.qkv_dim = 256,
};
config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
return config;
}

static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.model_dim = 128;
config.vocab_size = 64;
config.seq_len = 32;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 256,
.heads = 4,
.kv_heads = 1,
.qkv_dim = 16,
};
config.layer_configs = {3, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
// This is required for optimize_test to pass.
config.final_cap = 30.0f;
return config;
}

static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.model_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
config.model_dim = 2560;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = 2048;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.griffin_dim = config.model_dim,
.ff_hidden_dim = 7680,
.heads = 10,
.kv_heads = 1,
.qkv_dim = 256,
.conv1d_width = 4,
.ff_biases = true,
.softmax_attn_output_biases = true,
.type = LayerAttentionType::kGriffinRecurrentBlock,
.activation = ActivationType::Gelu,
.post_qk = PostQKType::Rope,
};
config.layer_configs = {26, layer_config};
for (size_t i = 2; i < config.layer_configs.size(); i += 3) {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.num_tensor_scales = 140;
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
config.use_local_attention = true;
// This is required for optimize_test to pass.
config.final_cap = 0.0f;
return config;
}

static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.vit_seq_len = 16 * 16;
LayerConfig layer_config = {
.model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304,
.heads = 16,
.kv_heads = 16,
.qkv_dim = 72,
.type = LayerAttentionType::kVit,
.patch_width = 14,
.image_size = 224,
};
config.vit_layer_configs = {27, layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
return config;
}

ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ConfigGemma2B();
case Model::GEMMA_7B:
return ConfigGemma7B();
case Model::GEMMA2_2B:
return ConfigGemma2_2B();
case Model::GEMMA2_9B:
return ConfigGemma2_9B();
case Model::GEMMA2_27B:
return ConfigGemma2_27B();
case Model::GRIFFIN_2B:
return ConfigGriffin2B();
case Model::GEMMA_TINY:
return ConfigGemmaTiny();
case Model::PALIGEMMA_224:
return ConfigPaliGemma_224();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}

} // namespace gcpp
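With configs now built at runtime, a caller can obtain every dimension it needs from ConfigFromModel. A short hedged sketch computing the per-position KV-cache element count for Gemma2 2B from the values defined above:

// Sketch: uses ConfigFromModel and the CachePosSize/CacheLayerSize helpers
// from configs.h. For the 2B values above: 26 layers * (4 kv_heads * 256
// qkv_dim * 2) = 53248 elements per position.
#include <cstdio>

void PrintCacheSize() {
  const ModelConfig config = ConfigFromModel(Model::GEMMA2_2B);
  std::printf("cache elements per pos: %zu\n", config.CachePosSize());
}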
411
gemma/configs.h

@@ -21,6 +21,9 @@
#include <stddef.h>

#include <array>
#include <string>
#include <unordered_set>
#include <vector>

#include "compression/shared.h" // BF16

@@ -57,6 +60,7 @@ enum class PostNormType {
// Post qk projection operation type.
enum class PostQKType {
Rope,
HalfRope,
};

// FFW activation function.

@@ -76,358 +80,115 @@ enum class ResidualType {
};

template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
std::vector<LayerAttentionType> FixedLayerConfig(LayerAttentionType type) {
return std::vector<LayerAttentionType>(kNum, type);
}

template <size_t kNum>
constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
size_t window_size) {
std::array<size_t, kNum> window_size_configs = {};
for (size_t& l : window_size_configs) {
l = window_size;
}
return window_size_configs;
std::vector<size_t> FixedAttentionWindowSizes(size_t window_size) {
return std::vector<size_t>(kNum, window_size);
}

// Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes(
std::vector<size_t> RepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize");
std::array<size_t, kNum> window_size_configs = {};
std::vector<size_t> window_size_configs(kNum);
for (size_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize];
}
return window_size_configs;
}

template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
// Model variants: see configs.cc for details.
enum class Model {
UNKNOWN,
GEMMA_2B,
GEMMA_7B,
GEMMA2_9B,
GEMMA2_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
PALIGEMMA_224,
};

struct LayerConfig {
size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; }

size_t model_dim = 0;
size_t griffin_dim = 0;
size_t ff_hidden_dim = 0;
size_t heads = 0;
size_t kv_heads = 0;
size_t qkv_dim = 0;
size_t conv1d_width = 0;
bool ff_biases = false;
bool softmax_attn_output_biases = false;
PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
};

struct ModelConfig {
size_t CachePosSize() const {
size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}

size_t NumLayersOfTypeBefore(LayerAttentionType type, size_t num) const {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
if (layer_configs[i].type == type) ++count;
}
return count;
}

template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
size_t NumLayersOfType(LayerAttentionType type) const {
return NumLayersOfTypeBefore(type, layer_configs.size());
}
};

template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
size_t NumHeads() const {
size_t num_heads = 0;
for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads);
}
return num_heads;
}

std::string model_name;
Model model;
ModelTraining training;
Type weight;
size_t model_dim = 0;
size_t vit_model_dim = 0;
size_t vocab_size = 0;
size_t seq_len = 0;
size_t vit_seq_len = 0;
size_t num_tensor_scales = 0;
size_t num_vit_scales = 0;
size_t top_k = kTopK;
float att_cap = 0.0f;
float final_cap = 0.0f;
bool absolute_pe = false;
bool use_local_attention = false;
QueryScaleType query_scale = QueryScaleType::SqrtKeySize;
std::vector<LayerConfig> layer_configs;
std::vector<size_t> attention_window_sizes;
std::vector<LayerConfig> vit_layer_configs;
std::unordered_set<std::string> scale_names;
int norm_num_groups = 1;
int model_family_version = 1;
};

struct ConfigNoVit {
struct VitConfig {
// Some of these are needed to make the compiler happy when trying to
// generate code that will actually never be used.
using Weight = float;
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
FixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};

struct ConfigNoSSM : ConfigNoVit {
static constexpr int kGriffinLayers = 0;

static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;
};

struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

struct ConfigBaseGemmaV2 : ConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
struct ConfigGemma2_27B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};

template <typename TWeight>
struct ConfigGemma2_9B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
FixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
FixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigPaliGemma_224 : public ConfigGemma2B<TWeight> {
// On the LM side, the vocab size is one difference to Gemma1-2B in the
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152

// Sub-config for the Vision-Transformer part.
struct VitConfig : public ConfigNoSSM {
using Weight = TWeight;
// The ViT parts. https://arxiv.org/abs/2305.13035
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
FixedLayerConfig<27>(LayerAttentionType::kVit);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kModelDim = 1152;
static constexpr int kFFHiddenDim = 4304;
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 72;
static constexpr int kSeqLen = 16 * 16; // 256
static constexpr bool kFFBiases = true;
// The Vit part does not have a vocabulary, the image patches are embedded.
static constexpr int kVocabSize = 0;
// Dimensions related to image processing.
static constexpr int kPatchWidth = 14;
static constexpr int kImageSize = 224;
// Necessary constant for the layer configuration.
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};

template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
FixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
FixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;

static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};

template <typename TWeight>
struct ConfigGriffin2B : ConfigNoVit {
using Weight = TWeight; // make accessible where we only have a TConfig

// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
FixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;

// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;
};
// Returns the config for the given model.
ModelConfig ConfigFromModel(Model model);

} // namespace gcpp
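The attention-window helpers keep their template signatures but now return std::vector, so the result can be stored in ModelConfig; for example, RepeatedAttentionWindowSizes<46, 2>({4096, 8192}) expands the two-entry pattern to 46 alternating values. A hedged sketch of building a toy config with these helpers (field names follow the structs in this hunk; the dimensions are made up):

// Sketch: toy ModelConfig using the runtime helpers above; not a real model.
ModelConfig ToyConfig() {
  ModelConfig config;
  config.model_dim = 512;
  config.vocab_size = 1000;
  config.seq_len = 1024;
  LayerConfig layer = {.model_dim = 512, .ff_hidden_dim = 2048,
                       .heads = 8, .kv_heads = 2, .qkv_dim = 64};
  config.layer_configs = {4, layer};  // 4 identical layers
  config.attention_window_sizes = RepeatedAttentionWindowSizes<4, 2>({256, 1024});
  return config;
}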
@@ -0,0 +1,445 @@
#include "gemma/configs.h"

#include <array>
#include <cstddef>
#include <type_traits>

#include "gtest/gtest.h"

namespace gcpp {

template <size_t kNum>
constexpr std::array<LayerAttentionType, kNum> OldFixedLayerConfig(
LayerAttentionType type) {
std::array<LayerAttentionType, kNum> config = {};
for (LayerAttentionType& l : config) {
l = type;
}
return config;
}

template <size_t kNum>
constexpr std::array<size_t, kNum> OldFixedAttentionWindowSizes(
size_t window_size) {
std::array<size_t, kNum> window_size_configs = {};
for (size_t& l : window_size_configs) {
l = window_size;
}
return window_size_configs;
}

// Repeat window_size_pattern for kNum / kPatternSize times.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> OldRepeatedAttentionWindowSizes(
const std::array<size_t, kPatternSize>& window_size_pattern) {
static_assert(kNum % kPatternSize == 0,
"kNum must be a multiple of kPatternSize");
std::array<size_t, kNum> window_size_configs = {};
for (size_t i = 0; i < kNum; ++i) {
window_size_configs[i] = window_size_pattern[i % kPatternSize];
}
return window_size_configs;
}

template <size_t kNumLayers>
constexpr size_t OldNumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
LayerAttentionType type, size_t num) {
size_t count = 0;
for (size_t i = 0; i < num; i++) {
if (layers[i] == type) count++;
}
return count;
}

template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
}
};

template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
}
};

struct OldConfigNoVit {
struct VitConfig {
// Some of these are needed to make the compiler happy when trying to
// generate code that will actually never be used.
using Weight = float;
static constexpr int kLayers = 0;
static constexpr std::array<LayerAttentionType, 0> kLayerConfig =
OldFixedLayerConfig<0>(LayerAttentionType::kVit);
static constexpr int kModelDim = 0;
static constexpr int kFFHiddenDim = 0;
static constexpr int kHeads = 1; // Avoid division by 0 in griffin gate_w.
static constexpr int kKVHeads = 0;
static constexpr int kQKVDim = 0;
static constexpr int kSeqLen = 0;
static constexpr ResidualType kResidual = ResidualType::Add;
static constexpr int kGriffinLayers = 0;
static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
};
};

struct OldConfigNoSSM : OldConfigNoVit {
static constexpr int kGriffinLayers = 0;

static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;
};

struct OldConfigBaseGemmaV1 : OldConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

struct OldConfigBaseGemmaV2 : OldConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
struct OldConfigGemma2_27B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
OldFixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};

template <typename TWeight>
struct OldConfigGemma2_9B : public OldConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
OldFixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
OldRepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct OldConfigGemma7B : public OldConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = gcpp::kVocabSize;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
OldFixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
OldFixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kNumTensorScales = 4 * kLayers;
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2B : public OldConfigBaseGemmaV1 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = gcpp::kSeqLen;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
|
||||
OldFixedLayerConfig<18>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<18>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2048;
|
||||
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigPaliGemma_224 : public OldConfigGemma2B<TWeight> {
|
||||
// On the LM side, the vocab size is one difference to Gemma1-2B in the
|
||||
// architecture. PaliGemma adds 1024 <locNNNN> and 128 <segNNN> tokens.
|
||||
static constexpr int kVocabSize = 256000 + 1024 + 128; // = 257152
|
||||
|
||||
// Sub-config for the Vision-Transformer part.
|
||||
struct VitConfig : public OldConfigNoSSM {
|
||||
using Weight = TWeight;
|
||||
// The ViT parts. https://arxiv.org/abs/2305.13035
|
||||
// "SoViT-400m/14 [...] has a width of 1152, depth 27, and MLP dim 4304."
|
||||
static constexpr std::array<LayerAttentionType, 27> kLayerConfig =
|
||||
OldFixedLayerConfig<27>(LayerAttentionType::kVit);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kModelDim = 1152;
|
||||
static constexpr int kFFHiddenDim = 4304;
|
||||
static constexpr int kHeads = 16;
|
||||
static constexpr int kKVHeads = 16; // standard MHA
|
||||
static constexpr int kQKVDim = 72;
|
||||
static constexpr int kSeqLen = 16 * 16; // 256
|
||||
static constexpr bool kFFBiases = true;
|
||||
// The Vit part does not have a vocabulary, the image patches are embedded.
|
||||
static constexpr int kVocabSize = 0;
|
||||
// Dimensions related to image processing.
|
||||
static constexpr int kPatchWidth = 14;
|
||||
static constexpr int kImageSize = 224;
|
||||
// Necessary constant for the layer configuration.
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
};
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemma2_2B : public OldConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
|
||||
OldFixedLayerConfig<26>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldRepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 2304;
|
||||
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
|
||||
static constexpr int kHeads = 8;
|
||||
static constexpr int kKVHeads = 4;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGemmaTiny : public OldConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 32;
|
||||
static constexpr int kVocabSize = 64;
|
||||
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
|
||||
OldFixedLayerConfig<3>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<3>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kNumTensorScales = 4 * kLayers;
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 128;
|
||||
static constexpr int kFFHiddenDim = 256;
|
||||
static constexpr int kHeads = 4;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 16; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
// This is required for optimize_test to pass.
|
||||
static constexpr float kFinalCap = 30.0f;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct OldConfigGriffin2B : OldConfigNoVit {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
// Griffin uses local attention, so kSeqLen is actually the local attention
|
||||
// window.
|
||||
static constexpr int kSeqLen = 2048;
|
||||
static constexpr int kVocabSize = gcpp::kVocabSize;
|
||||
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGemma,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
LayerAttentionType::kGriffinRecurrentBlock,
|
||||
};
|
||||
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
|
||||
OldFixedAttentionWindowSizes<26>(kSeqLen);
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGemma, kLayers);
|
||||
static constexpr int kGriffinLayers = OldNumLayersOfTypeBefore(
|
||||
kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
|
||||
static constexpr int kModelDim = 2560;
|
||||
static constexpr int kFFHiddenDim = 7680;
|
||||
static constexpr int kHeads = 10;
|
||||
static constexpr int kKVHeads = 1;
|
||||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
|
||||
// No SoftCap.
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
|
||||
// SSM config.
|
||||
static constexpr int kConv1dWidth = 4;
|
||||
static constexpr bool kFFBiases = true;
|
||||
static constexpr bool kSoftmaxAttnOutputBiases = true;
|
||||
static constexpr bool kUseHalfRope = true;
|
||||
static constexpr bool kUseLocalAttention = true;
|
||||
static constexpr bool kInterleaveQKV = false;
|
||||
static constexpr int kNumTensorScales = 140;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
void AssertMatch(const ModelConfig& config) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.model_dim);
|
||||
if constexpr (TConfig::VitConfig::kModelDim != 0) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kModelDim, config.vit_model_dim);
|
||||
ASSERT_EQ(TConfig::VitConfig::kSeqLen, config.vit_seq_len);
|
||||
ASSERT_EQ(TConfig::VitConfig::kNumTensorScales, config.num_vit_scales);
|
||||
for (size_t i = 0; i < config.vit_layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::VitConfig::kLayerConfig[i],
|
||||
config.vit_layer_configs[i].type);
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
|
||||
ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
|
||||
ASSERT_EQ(TConfig::kTopK, config.top_k);
|
||||
ASSERT_EQ(TConfig::kAttCap, config.att_cap);
|
||||
ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
|
||||
ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);
|
||||
ASSERT_EQ(TConfig::kUseLocalAttention, config.use_local_attention);
|
||||
ASSERT_EQ(TConfig::kQueryScale, config.query_scale);
|
||||
ASSERT_EQ(TConfig::kGemmaLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGemma));
|
||||
ASSERT_EQ(TConfig::kGriffinLayers,
|
||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock));
|
||||
for (size_t i = 0; i < config.layer_configs.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kModelDim, config.layer_configs[i].model_dim);
|
||||
ASSERT_EQ(TConfig::kFFHiddenDim, config.layer_configs[i].ff_hidden_dim);
|
||||
ASSERT_EQ(TConfig::kHeads, config.layer_configs[i].heads);
|
||||
ASSERT_EQ(TConfig::kKVHeads, config.layer_configs[i].kv_heads);
|
||||
ASSERT_EQ(TConfig::kQKVDim, config.layer_configs[i].qkv_dim);
|
||||
ASSERT_EQ(TConfig::kConv1dWidth, config.layer_configs[i].conv1d_width);
|
||||
ASSERT_EQ(TConfig::kFFBiases, config.layer_configs[i].ff_biases);
|
||||
ASSERT_EQ(TConfig::kSoftmaxAttnOutputBiases,
|
||||
config.layer_configs[i].softmax_attn_output_biases);
|
||||
ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm);
|
||||
ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type);
|
||||
ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation);
|
||||
ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk);
|
||||
}
|
||||
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes.size(),
|
||||
config.attention_window_sizes.size());
|
||||
for (size_t i = 0; i < config.attention_window_sizes.size(); ++i) {
|
||||
ASSERT_EQ(TConfig::kAttentionWindowSizes[i],
|
||||
config.attention_window_sizes[i]);
|
||||
}
|
||||
ASSERT_EQ(TConfig::kNumTensorScales, config.num_tensor_scales);
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2B) {
|
||||
AssertMatch<OldConfigGemma2B<float>>(ConfigFromModel(Model::GEMMA_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma7B) {
|
||||
AssertMatch<OldConfigGemma7B<float>>(ConfigFromModel(Model::GEMMA_7B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_2B) {
|
||||
AssertMatch<OldConfigGemma2_2B<float>>(ConfigFromModel(Model::GEMMA2_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_9B) {
|
||||
AssertMatch<OldConfigGemma2_9B<float>>(ConfigFromModel(Model::GEMMA2_9B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemma2_27B) {
|
||||
AssertMatch<OldConfigGemma2_27B<float>>(ConfigFromModel(Model::GEMMA2_27B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGriffin2B) {
|
||||
AssertMatch<OldConfigGriffin2B<float>>(ConfigFromModel(Model::GRIFFIN_2B));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigGemmaTiny) {
|
||||
AssertMatch<OldConfigGemmaTiny<float>>(ConfigFromModel(Model::GEMMA_TINY));
|
||||
}
|
||||
|
||||
TEST(ConfigsTest, OldConfigPaliGemma_224) {
|
||||
AssertMatch<OldConfigPaliGemma_224<float>>(
|
||||
ConfigFromModel(Model::PALIGEMMA_224));
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
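For orientation, a minimal sketch (not part of the commit) of how the runtime ModelConfig replaces the compile-time TConfig constants checked above. It only uses names that appear in the test (ConfigFromModel, model_dim, seq_len, layer_configs) and assumes the numeric fields convert to size_t:

#include <cstdio>

#include "gemma/configs.h"

int main() {
  // One runtime config object per model, instead of one config template per model.
  const gcpp::ModelConfig config = gcpp::ConfigFromModel(gcpp::Model::GEMMA2_2B);
  printf("model_dim=%zu seq_len=%zu layers=%zu\n",
         static_cast<size_t>(config.model_dim),
         static_cast<size_t>(config.seq_len), config.layer_configs.size());
  return 0;
}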
gemma/gemma-inl.h (1027): file diff suppressed because it is too large.
@@ -29,88 +29,90 @@
#include "compression/io.h"  // Path
#include "gemma/common.h"
#include "gemma/weights.h"
#include "ops/ops-inl.h"
#include "paligemma/image.h"
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"  // also uses SIMD

namespace gcpp {

Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
             const ModelInfo& info, PerClusterPools& pools)
    : pools_(pools), tokenizer_(tokenizer_path), info_(info) {
  weights_u8_ =
      LoadCompressedWeights(weights, info.model, info.weight, pools_.Inner(0));
  model_.Load(weights, info.model, info.weight, pools_.Inner(0));
}

Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
             PerClusterPools& pools)
    : pools_(pools), tokenizer_(std::move(tokenizer)), info_(info) {
  HWY_ASSERT(info.weight == Type::kF32);
  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(info.model,
                                                               pools_.Inner(0));
  model_.Allocate(info.model, info.weight, pools_.Inner(0));
}

Gemma::~Gemma() {
}

// There are >100 instantiations of the inference code. To reduce compile time,
// There are >=3 types of the inference code. To reduce compile time,
// we shard them across multiple translation units in instantiations/*.cc.
// This declares the functions defined there. We use overloading because
// explicit instantiations are still too slow to compile.
#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \
  extern void GenerateSingle(CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
#define GEMMA_DECLARE(TWEIGHT) \
  extern void GenerateSingle(TWEIGHT, const ModelWeightsStorage& model, \
                             const RuntimeConfig& runtime_config, \
                             const PromptTokens& prompt, size_t pos, \
                             size_t prefix_end, KVCache& kv_cache, \
                             PerClusterPools& pools, TimingInfo& timing_info); \
  extern void GenerateBatch( \
      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
      TWEIGHT, const ModelWeightsStorage& model, \
      const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
      const QueriesPos& queries_pos, \
      const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, \
      PerClusterPools& pools, TimingInfo& timing_info); \
      const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
      const KVCaches& kv_caches, PerClusterPools& pools, \
      TimingInfo& timing_info); \
  extern void GenerateImageTokens( \
      CONFIGT<TWEIGHT>, const ByteStorageT& weights_u8, \
      TWEIGHT, const ModelWeightsStorage& model, \
      const RuntimeConfig& runtime_config, const Image& image, \
      ImageTokens& image_tokens, PerClusterPools& pools);
GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE);
GEMMA_DECLARE(float)
GEMMA_DECLARE(BF16)
GEMMA_DECLARE(NuqStream)
GEMMA_DECLARE(SfpStream)

// Adapters to select from the above overloads via CallForModelAndWeight.
// Adapters to select from the above overloads via CallForModelWeight.
template <class TConfig>
struct GenerateSingleT {
  void operator()(const ByteStorageT& weights_u8,
  void operator()(const ModelWeightsStorage& model,
                  const RuntimeConfig& runtime_config,
                  const PromptTokens& prompt, size_t pos, size_t prefix_end,
                  KVCache& kv_cache, PerClusterPools& pools,
                  TimingInfo& timing_info) const {
    GenerateSingle(TConfig(), weights_u8, runtime_config, prompt, pos,
                   prefix_end, kv_cache, pools, timing_info);
    GenerateSingle(TConfig(), model, runtime_config, prompt, pos, prefix_end,
                   kv_cache, pools, timing_info);
  }
};

template <class TConfig>
struct GenerateBatchT {
  void operator()(const ByteStorageT& weights_u8,
  void operator()(const ModelWeightsStorage& model,
                  const RuntimeConfig& runtime_config,
                  const QueriesPromptTokens& queries_prompt,
                  const QueriesPos& queries_pos,
                  const QueriesPos& queries_prefix_end,
                  const KVCaches& kv_caches, PerClusterPools& pools,
                  TimingInfo& timing_info) const {
    GenerateBatch(TConfig(), weights_u8, runtime_config, queries_prompt,
                  queries_pos, queries_prefix_end, kv_caches, pools,
                  timing_info);
    GenerateBatch(TConfig(), model, runtime_config, queries_prompt, queries_pos,
                  queries_prefix_end, kv_caches, pools, timing_info);
  }
};

template <class TConfig>
struct GenerateImageTokensT {
  void operator()(const ByteStorageT& weights_u8,
  void operator()(const ModelWeightsStorage& model,
                  const RuntimeConfig& runtime_config, const Image& image,
                  ImageTokens& image_tokens, PerClusterPools& pools) const {
    GenerateImageTokens(TConfig(), weights_u8, runtime_config, image,
                        image_tokens, pools);
    GenerateImageTokens(TConfig(), model, runtime_config, image, image_tokens,
                        pools);
  }
};

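To make the overload-based dispatch concrete: GEMMA_DECLARE(float) above expands (first declaration shown, line breaks aside) to an overload whose unnamed leading parameter exists only so the right weight type is selected at the call site:

// Expansion sketch of GEMMA_DECLARE(float); the leading float is a dispatch
// tag, passed as TConfig() by the adapter structs above.
extern void GenerateSingle(float, const ModelWeightsStorage& model,
                           const RuntimeConfig& runtime_config,
                           const PromptTokens& prompt, size_t pos,
                           size_t prefix_end, KVCache& kv_cache,
                           PerClusterPools& pools, TimingInfo& timing_info);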
@@ -119,9 +121,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                     KVCache& kv_cache, TimingInfo& timing_info) {
  if (runtime_config.use_spinning) pools_.StartSpinning();

  CallForModelAndWeight<GenerateSingleT>(
      info_.model, info_.weight, weights_u8_, runtime_config, prompt, pos,
      prefix_end, kv_cache, pools_, timing_info);
  model_.CallForModelWeight<GenerateSingleT>(
      runtime_config, prompt, pos, prefix_end, kv_cache, pools_, timing_info);

  if (runtime_config.use_spinning) pools_.StopSpinning();
}

@@ -142,9 +143,9 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,

  if (runtime_config.use_spinning) pools_.StartSpinning();

  CallForModelAndWeight<GenerateBatchT>(
      info_.model, info_.weight, weights_u8_, runtime_config, queries_prompt,
      queries_pos, mutable_queries_prefix_end, kv_caches, pools_, timing_info);
  model_.CallForModelWeight<GenerateBatchT>(
      runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
      kv_caches, pools_, timing_info);

  if (runtime_config.use_spinning) pools_.StopSpinning();
}

@@ -153,28 +154,25 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                const Image& image, ImageTokens& image_tokens) {
  if (runtime_config.use_spinning) pools_.StartSpinning();

  CallForModelAndWeight<GenerateImageTokensT>(info_.model, info_.weight,
                                              weights_u8_, runtime_config,
                                              image, image_tokens, pools_);
  model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
                                                  image_tokens, pools_);

  if (runtime_config.use_spinning) pools_.StopSpinning();
}

template <typename TConfig>
struct GetModelConfig {
  ModelConfigInfo operator()() const {
    return ModelConfigInfo{
        .layers = TConfig::kLayers,
        .model_dim = TConfig::kModelDim,
        .heads = TConfig::kHeads,
        .kv_heads = TConfig::kKVHeads,
        .qkv_dim = TConfig::kQKVDim,
    };
  }
};
// Non-template functions moved from gemma-inl.h to avoid ODR violations.

ModelConfigInfo Gemma::ModelConfig() const {
  return CallForModel<float, GetModelConfig>(info_.model);
void RangeChecks(const ModelConfig& weights_config,
                 size_t& max_generated_tokens, const size_t prompt_size) {
  if (!weights_config.use_local_attention) {
    if (max_generated_tokens > weights_config.seq_len) {
      fprintf(stderr,
              "WARNING: max_generated_tokens %zu > kSeqLen %zu, truncating.\n",
              max_generated_tokens, weights_config.seq_len);
      max_generated_tokens = weights_config.seq_len;
    }
  }
  HWY_ASSERT(prompt_size > 0);
}

}  // namespace gcpp
@@ -27,6 +27,7 @@
#include "gemma/common.h"
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "util/allocator.h"  // RowVectorBatch
#include "util/basics.h"     // TokenAndProb

@@ -179,15 +180,6 @@ struct TimingInfo {
  size_t tokens_generated = 0;
};

// ModelConfigInfo holds model configuration details: number of layers, etc.
struct ModelConfigInfo {
  const int layers;
  const int model_dim;
  const int heads;
  const int kv_heads;
  const int qkv_dim;
};

class Gemma {
 public:
  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,

@@ -198,11 +190,11 @@ class Gemma {
        PerClusterPools& pools);
  ~Gemma();

  ModelConfigInfo ModelConfig() const;
  const ModelConfig& GetModelConfig() const { return model_.Config(); }
  const ModelInfo& Info() const { return info_; }
  const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
  const ByteStorageT& Weights() const { return weights_u8_; }
  ByteStorageT& MutableWeights() { return weights_u8_; }
  const ModelWeightsStorage& Weights() const { return model_; }
  ModelWeightsStorage& MutableWeights() { return model_; }

  // `pos` is the position in the KV cache. Users are responsible for
  // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.

@@ -241,7 +233,7 @@ class Gemma {

  GemmaTokenizer tokenizer_;
  // Type-erased so that this can be defined in the header.
  ByteStorageT weights_u8_;
  ModelWeightsStorage model_;
  ModelInfo info_;
};

@@ -251,6 +243,8 @@ class Gemma {
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                 const ModelInfo& info, size_t pos,
                                 std::string& prompt);
void RangeChecks(const ModelConfig& weights_config,
                 size_t& max_generated_tokens, size_t prompt_size);

}  // namespace gcpp
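A brief usage sketch of the accessor change above; `model` is a Gemma constructed as in the Run() hunk later in this diff, and the fprintf is illustrative only:

// Old: ModelConfigInfo info = model.ModelConfig();  (struct removed above)
// New: borrow the full runtime config owned by ModelWeightsStorage.
const gcpp::ModelConfig& config = model.GetModelConfig();
fprintf(stderr, "layers=%zu, model_dim=%zu\n", config.layer_configs.size(),
        static_cast<size_t>(config.model_dim));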
@@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/27b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/27b_f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<float>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/27b_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_27B<SfpStream>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/2b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/7b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/7b_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<SfpStream>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/9b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/9b_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<SfpStream>
#include "gemma/gemma-inl.h"

@@ -14,8 +14,7 @@
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/2b_f32.cc"
#define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<float>
#define GEMMA_TYPE hwy::bfloat16_t
#include "gemma/gemma-inl.h"

@@ -14,8 +14,7 @@
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/7b_f32.cc"
#define HWY_TARGET_INCLUDE "gemma/instantiations/f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma7B<float>
#define GEMMA_TYPE float
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gemma2_2b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gemma2_2b_f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<float>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gemma2_2b_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gr2b_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gr2b_f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<float>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/gr2b_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGriffin2B<SfpStream>
#include "gemma/gemma-inl.h"

@@ -14,8 +14,7 @@
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/9b_f32.cc"
#define HWY_TARGET_INCLUDE "gemma/instantiations/nuq.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2_9B<float>
#define GEMMA_TYPE NuqStream
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/paligemma_224_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/paligemma_224_f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<float>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/paligemma_224_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigPaliGemma_224<SfpStream>
#include "gemma/gemma-inl.h"

@@ -14,8 +14,7 @@
// limitations under the License.

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/2b_sfp.cc"
#define HWY_TARGET_INCLUDE "gemma/instantiations/sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma2B<SfpStream>
#define GEMMA_TYPE SfpStream
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/tiny_bf16.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/tiny_f32.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<float>
#include "gemma/gemma-inl.h"

@@ -1,21 +0,0 @@
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
    "gemma/instantiations/tiny_sfp.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemmaTiny<SfpStream>
#include "gemma/gemma-inl.h"
@@ -15,32 +15,40 @@

#include "gemma/kv_cache.h"

#include <algorithm>

#include "gemma/common.h"  // CallForModel
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // ZeroBytes

namespace gcpp {
namespace {
template <class TConfig>
struct CreateKVCache {
  KVCache operator()(size_t prefill_tbatch_size) const {

// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache KVCache::Create(const ModelConfig& weights_config,
                        size_t prefill_tbatch_size) {
  KVCache kv_cache = {};

  const size_t size_cache_pos = CachePosSize<TConfig>()();
  const size_t size_cache_pos = weights_config.CachePosSize();
  if (size_cache_pos != 0) {
    // Allocate more so that prefill can always access one batch, even if
    // near the end of the sequence.
    kv_cache.seq_len = TConfig::kSeqLen + prefill_tbatch_size;
    kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
    kv_cache.kv_cache =
        hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
  }
  size_t num_griffin_layers = weights_config.NumLayersOfType(
      LayerAttentionType::kGriffinRecurrentBlock);

  // TODO(patrickms): Add query batching support for Griffin.
  if (TConfig::kGriffinLayers) {
    constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
  if (num_griffin_layers > 0) {
    size_t conv1d_width = 0;
    for (const auto& layer_config : weights_config.layer_configs) {
      conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
    }
    const size_t conv1d_cache_size =
        TConfig::kGriffinLayers * (kConv1dWidth == 0 ? 0 : kConv1dWidth - 1) *
        TConfig::kModelDim;
        num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
        weights_config.model_dim;
    if (conv1d_cache_size != 0) {
      kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
      hwy::ZeroBytes(kv_cache.conv1d_cache.get(),

@@ -48,7 +56,7 @@ struct CreateKVCache {
    }

    const size_t rglru_cache_size =
        TConfig::kGriffinLayers * TConfig::kModelDim;
        num_griffin_layers * weights_config.model_dim;
    if (rglru_cache_size != 0) {
      kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
      hwy::ZeroBytes(kv_cache.rglru_cache.get(),

@@ -58,16 +66,5 @@ struct CreateKVCache {

  return kv_cache;
}
};
}  // namespace

// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache KVCache::Create(Model model_type, size_t prefill_tbatch_size) {
  // TWeight=float is a placeholder and unused because CreateKVCache does not
  // use TConfig::Weight.
  return CallForModel</*TWeight=*/float, CreateKVCache>(model_type,
                                                        prefill_tbatch_size);
}

}  // namespace gcpp
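As a usage sketch of the new Create overload (the same pattern appears in the Run() hunk further down; loader, pools and inference are assumed to be set up as there):

// The cache is now sized from the runtime ModelConfig rather than a Model
// enum, so no compile-time TConfig is involved.
Gemma model = CreateGemma(loader, pools);
KVCache kv_cache =
    KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);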
@@ -35,7 +35,8 @@ struct KVCache {
  // kModelDim * kGriffinLayers
  hwy::AlignedFreeUniquePtr<float[]> rglru_cache;

  static KVCache Create(Model type, size_t prefill_tbatch_size);
  static KVCache Create(const ModelConfig& weights_config,
                        size_t prefill_tbatch_size);
};

}  // namespace gcpp
@@ -194,7 +194,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {

  Gemma model = CreateGemma(loader, pools);
  KVCache kv_cache =
      KVCache::Create(model.Info().model, inference.prefill_tbatch_size);
      KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);

  if (app.verbosity >= 1) {
    std::string instructions =
gemma/weights.cc (192)

@@ -17,12 +17,14 @@

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <random>
#include <vector>

#include "compression/compress.h"
#include "compression/io.h"  // Path
#include "gemma/common.h"
#include "util/allocator.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // HWY_ABORT
#include "hwy/contrib/thread_pool/thread_pool.h"

@@ -31,58 +33,128 @@

namespace gcpp {

namespace {
template <class TConfig>
struct LoadCompressedWeightsT {
  ByteStorageT operator()(const Path& weights, hwy::ThreadPool& pool) const {
    PROFILER_ZONE("Startup.LoadCompressedWeights");
template <typename T>
struct TensorLoader {
  void operator()(ModelWeightsPtrs<T>& weights, ForEachType fet,
                  CacheLoader& loader) {
    weights.ForEachTensor(
        {&weights}, fet,
        [&loader](const char* name, hwy::Span<MatPtr*> tensors) {
          loader(name, tensors);
        });
  }
};

BlobError ModelWeightsStorage::Load(const Path& weights, Model model_type,
                                    Type weight_type, hwy::ThreadPool& pool) {
  PROFILER_ZONE("Startup.LoadModelWeightsPtrs");
  if (!weights.Exists()) {
    HWY_ABORT("The model weights file '%s' does not exist.",
              weights.path.c_str());
  }

  // Allocate compressed weights.
  using CWeights = CompressedWeights<TConfig>;
  ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
  CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
  new (c_weights) CWeights(pool);

  CacheLoader loader(weights);
  ForEachType fet =
      loader.HaveToc() ? ForEachType::kLoadWithToc : ForEachType::kLoadNoToc;
  CWeights::ForEachTensor(
      {c_weights}, fet,
      [&loader](const char* name, hwy::Span<MatPtr*> tensors) {
        loader(name, tensors);
      });
  std::vector<float> scales(TConfig::kNumTensorScales);
  if (TConfig::kNumTensorScales > 0) {
  if (fet == ForEachType::kLoadWithToc) {
    // TODO(rays): Load the config from the file.
    HWY_ABORT("TOC not supported yet.");
  } else {
    // No Toc-> no config.
    config_ = ConfigFromModel(model_type);
    config_.weight = weight_type;
  }
  CreateForType(weight_type, pool);
  CallForModelWeightT<TensorLoader>(fet, loader);
  std::vector<float> scales(config_.num_tensor_scales + config_.num_vit_scales);
  if (!scales.empty()) {
    loader.LoadScales(scales.data(), scales.size());
  }
  if (!loader.ReadAll(pool, c_weights->model_storage)) {
    HWY_ABORT("Failed to load model weights.");
  BlobError err = loader.ReadAll(pool, model_storage_);
  if (err != 0) {
    fprintf(stderr, "Failed to load model weights: %d\n", err);
    return err;
  }
  if (TConfig::kNumTensorScales > 0) {
    c_weights->GetOrApplyScales(scales);
  if (!scales.empty()) {
    GetOrApplyScales(scales);
  }
  {
  if (fet == ForEachType::kLoadNoToc) {
    PROFILER_ZONE("Startup.Reshape");
    c_weights->Reshape(pool);
    AllocAndCopyWithTranspose(pool);
  }
  return c_weights_u8;
  return 0;
}
};
}  // namespace

ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type,
                                   Type weight_type, hwy::ThreadPool& pool) {
  return CallForModelAndWeight<LoadCompressedWeightsT>(model_type, weight_type,
                                                       weights, pool);
void ModelWeightsStorage::Allocate(const ModelConfig& config, Type weight_type,
                                   hwy::ThreadPool& pool) {
  PROFILER_ZONE("Startup.AllocateModelWeightsPtrs");
  config_ = config;
  config_.weight = weight_type;
  CreateForType(weight_type, pool);
  if (float_weights_) float_weights_->Allocate(model_storage_, pool);
  if (bf16_weights_) bf16_weights_->Allocate(model_storage_, pool);
  if (sfp_weights_) sfp_weights_->Allocate(model_storage_, pool);
  if (nuq_weights_) nuq_weights_->Allocate(model_storage_, pool);
}

class WeightInitializer {
 public:
  WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}

  void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
    float* data = tensors[0]->data<float>();
    for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
      data[i] = dist_(gen_);
    }
    tensors[0]->set_scale(1.0f);
  }

 private:
  std::normal_distribution<float> dist_;
  std::mt19937& gen_;
};

void ModelWeightsStorage::RandInit(std::mt19937& gen) {
  HWY_ASSERT(float_weights_);
  WeightInitializer init(gen);
  ModelWeightsPtrs<float>::ForEachTensor({float_weights_.get()},
                                         ForEachType::kLoadNoToc, init);
}

void ModelWeightsStorage::ZeroInit() {
  if (float_weights_) float_weights_->ZeroInit();
  if (bf16_weights_) bf16_weights_->ZeroInit();
  if (sfp_weights_) sfp_weights_->ZeroInit();
  if (nuq_weights_) nuq_weights_->ZeroInit();
}

void ModelWeightsStorage::GetOrApplyScales(std::vector<float>& scales) {
  if (float_weights_) float_weights_->GetOrApplyScales(scales);
  if (bf16_weights_) bf16_weights_->GetOrApplyScales(scales);
  if (sfp_weights_) sfp_weights_->GetOrApplyScales(scales);
  if (nuq_weights_) nuq_weights_->GetOrApplyScales(scales);
}

void ModelWeightsStorage::AllocAndCopyWithTranspose(hwy::ThreadPool& pool) {
  if (float_weights_)
    float_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
  if (bf16_weights_)
    bf16_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
  if (sfp_weights_)
    sfp_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
  if (nuq_weights_)
    nuq_weights_->AllocAndCopyWithTranspose(pool, model_storage_);
}

void ModelWeightsStorage::CopyWithTranspose(hwy::ThreadPool& pool) {
  if (float_weights_) float_weights_->CopyWithTranspose(pool);
  if (bf16_weights_) bf16_weights_->CopyWithTranspose(pool);
  if (sfp_weights_) sfp_weights_->CopyWithTranspose(pool);
  if (nuq_weights_) nuq_weights_->CopyWithTranspose(pool);
}

namespace {
// For reasons unknown, this is shown as potentially unused in the IDE.
void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) {

void LogVec(const char* name, const float* data, size_t len) {
  hwy::Stats stats;
  for (size_t i = 0; i < len; ++i) {
    stats.Notify(data[i]);

@@ -91,36 +163,44 @@ void HWY_MAYBE_UNUSED LogVec(const char* name, const float* data, size_t len) {
         name, len, stats.Min(), stats.Mean(), stats.Max());
}

class WeightLogger {
 public:
  void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
}  // namespace

void ModelWeightsStorage::LogWeightStats() {
  size_t total_weights = 0;
  // Only for float weights.
  ModelWeightsPtrs<float>::ForEachTensor(
      {float_weights_.get()}, ForEachType::kInitNoToc,
      [&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
        const MatPtr& tensor = *tensors[0];
        if (tensor.scale() != 1.0f) {
          printf("[scale=%f] ", tensor.scale());
        }
        LogVec(name, tensor.data<float>(), tensor.NumElements());
|
||||
total_weights += tensor.NumElements();
|
||||
});
|
||||
printf("%-20s %12zu\n", "Total", total_weights);
|
||||
}
|
||||
size_t total_weights = 0;
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct LogWeightStatsT {
|
||||
void operator()(const ByteStorageT& weights_u8) const {
|
||||
auto& weights =
|
||||
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
|
||||
WeightLogger logger;
|
||||
CompressedWeights<TConfig>::ForEachTensor(
|
||||
{&weights}, ForEachType::kIgnoreNulls, logger);
|
||||
printf("%-20s %12zu\n", "Total", logger.total_weights);
|
||||
void ModelWeightsStorage::CreateForType(Type weight_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
float_weights_ = std::make_unique<ModelWeightsPtrs<float>>(config_, pool);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
bf16_weights_ = std::make_unique<ModelWeightsPtrs<BF16>>(config_, pool);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
sfp_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<SfpStream>>(config_, pool);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
nuq_weights_ =
|
||||
std::make_unique<ModelWeightsPtrs<NuqStream>>(config_, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LogWeightStats(gcpp::Model model_type, Type weight_type,
|
||||
const ByteStorageT& weights) {
|
||||
HWY_ASSERT(weight_type == Type::kF32);
|
||||
CallForModel<float, LogWeightStatsT>(model_type, weights);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
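For orientation, here is a minimal usage sketch of the new entry point defined above. It is not part of this commit: the file path, pool size, model and weight-type values are placeholders, and error handling is reduced to one check.

#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Sketch only: assumes an f32 checkpoint whose stored type matches weight_type.
void LoadAndLogSketch() {
  hwy::ThreadPool pool(4);
  gcpp::Path weights_path;
  weights_path.path = "/tmp/model.sbs";  // Placeholder path.
  gcpp::ModelWeightsStorage storage;
  gcpp::BlobError err = storage.Load(weights_path, gcpp::Model::GEMMA_2B,
                                     gcpp::Type::kF32, pool);
  if (err != 0) return;  // Load() already printed the error code.
  storage.LogWeightStats();  // Per the code above, stats cover float weights only.
}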
gemma/weights.h (455 changed lines)
@@ -18,9 +18,10 @@
#include <stddef.h>

#include <array>
#include <complex>
#include <cstdio>
#include <memory>
#include <random>
#include <string>
#include <unordered_set>
#include <vector>

@@ -29,7 +30,6 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@@ -53,57 +53,79 @@ enum class ForEachType {
kInitNoToc,
};

template <class TConfig>
struct CompressedLayer {
template <class Weight>
struct LayerWeightsPtrs {
// Large data is constructed separately.
CompressedLayer()
: attn_vec_einsum_w("att_ein", kModelDim, kHeads * kQKVDim),
qkv_einsum_w("qkv_ein", (kHeads + 2 * kKVHeads) * kQKVDim, kModelDim),
qkv_einsum_w1("qkv1_w", kHeads * kQKVDim, kModelDim),
qkv_einsum_w2("qkv2_w", 2 * kKVHeads * kQKVDim, kModelDim),
attention_output_biases("attn_ob", 1, kAOBiasDim),
griffin({.linear_x_w = {"gr_lin_x_w", kGriffinDim, kGriffinDim},
.linear_x_biases = {"gr_lin_x_b", 1, kGriffinDim},
.linear_y_w = {"gr_lin_y_w", kGriffinDim, kGriffinDim},
.linear_y_biases = {"gr_lin_y_b", 1, kGriffinDim},
.linear_out_w = {"gr_lin_out_w", kGriffinDim, kGriffinDim},
.linear_out_biases = {"gr_lin_out_b", 1, kGriffinDim},
.conv_w = {"gr_conv_w", kConv1dWidth, kGriffinDim},
.conv_biases = {"gr_conv_b", 1, kGriffinDim},
.gate_w = {"gr_gate_w", 2 * kGriffinDim, kGriffinDim / kHeads},
.gate_biases = {"gr_gate_b", 1, kGriffinDim * 2},
.a = {"gr_a", 1, kGriffinDim}}),
explicit LayerWeightsPtrs(const LayerConfig& config)
: attn_vec_einsum_w("att_ein", config.model_dim, config.heads * config.qkv_dim),
qkv_einsum_w("qkv_ein", (config.heads + 2 * config.kv_heads) * config.qkv_dim, config.model_dim),
qkv_einsum_w1("qkv1_w", config.heads * config.qkv_dim, config.model_dim),
qkv_einsum_w2("qkv2_w", 2 * config.kv_heads * config.qkv_dim, config.model_dim),
attention_output_biases("attn_ob", 1, config.softmax_attn_output_biases ? config.model_dim : 0),
griffin({.linear_x_w = {"gr_lin_x_w", config.griffin_dim, config.griffin_dim},
.linear_x_biases = {"gr_lin_x_b", 1, config.griffin_dim},
.linear_y_w = {"gr_lin_y_w", config.griffin_dim, config.griffin_dim},
.linear_y_biases = {"gr_lin_y_b", 1, config.griffin_dim},
.linear_out_w = {"gr_lin_out_w", config.griffin_dim, config.griffin_dim},
.linear_out_biases = {"gr_lin_out_b", 1, config.griffin_dim},
.conv_w = {"gr_conv_w", config.conv1d_width, config.griffin_dim},
.conv_biases = {"gr_conv_b", 1, config.griffin_dim},
.gate_w = {"gr_gate_w", 2 * config.griffin_dim, config.griffin_dim / config.heads},
.gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2},
.a = {"gr_a", 1, config.griffin_dim}}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = {"attn_out_w", kHeads * kQKVDim, kModelDim},
.attn_out_b = {"attn_out_b", 1, kModelDim},
.qkv_einsum_w = {"qkv_ein_w", (kHeads + 2 * kKVHeads) * kQKVDim, kModelDim},
.qkv_einsum_b = {"qkv_ein_b", (kHeads + 2 * kKVHeads), kQKVDim},
.linear_0_w = {"linear_0_w", kModelDim, kFFHiddenDim},
.linear_0_b = {"linear_0_b", 1, kFFHiddenDim},
.linear_1_w = {"linear_1_w", kFFHiddenDim, kModelDim},
.linear_1_b = {"linear_1_b", 1, kModelDim},
.layer_norm_0_bias = {"ln_0_bias", 1, kModelDim},
.layer_norm_0_scale = {"ln_0_scale", 1, kModelDim},
.layer_norm_1_bias = {"ln_1_bias", 1, kModelDim},
.layer_norm_1_scale = {"ln_1_scale", 1, kModelDim}}),
gating_einsum_w("gating_ein", 2 * kFFHiddenDim, kModelDim),
gating_einsum_w1("gating1_w", kFFHiddenDim, kModelDim),
gating_einsum_w2("gating2_w", kFFHiddenDim, kModelDim),
linear_w("linear_w", kModelDim, kFFHiddenDim),
pre_attention_norm_scale("pre_att_ns", 1, kModelDim),
pre_ffw_norm_scale("pre_ff_ns", 1, kModelDim),
vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim, config.model_dim},
.attn_out_b = {"attn_out_b", 1, config.model_dim},
.qkv_einsum_w = {"qkv_ein_w", (config.heads + 2 * config.kv_heads) * config.qkv_dim, config.model_dim},
.qkv_einsum_b = {"qkv_ein_b", (config.heads + 2 * config.kv_heads), config.qkv_dim},
.linear_0_w = {"linear_0_w", config.model_dim, config.ff_hidden_dim},
.linear_0_b = {"linear_0_b", 1, config.ff_hidden_dim},
.linear_1_w = {"linear_1_w", config.ff_hidden_dim, config.model_dim},
.linear_1_b = {"linear_1_b", 1, config.model_dim},
.layer_norm_0_bias = {"ln_0_bias", 1, config.model_dim},
.layer_norm_0_scale = {"ln_0_scale", 1, config.model_dim},
.layer_norm_1_bias = {"ln_1_bias", 1, config.model_dim},
.layer_norm_1_scale = {"ln_1_scale", 1, config.model_dim}}),
gating_einsum_w("gating_ein", 2 * config.ff_hidden_dim, config.model_dim),
gating_einsum_w1("gating1_w", config.ff_hidden_dim, config.model_dim),
gating_einsum_w2("gating2_w", config.ff_hidden_dim, config.model_dim),
linear_w("linear_w", config.model_dim, config.ff_hidden_dim),
pre_attention_norm_scale("pre_att_ns", 1, config.model_dim),
pre_ffw_norm_scale("pre_ff_ns", 1, config.model_dim),
post_attention_norm_scale("post_att_ns", 1, kPostNorm == PostNormType::Scale ? kModelDim : 0),
post_ffw_norm_scale("post_ff_ns", 1, kPostNorm == PostNormType::Scale ? kModelDim : 0),
ffw_gating_biases("ffw_gat_b", 1, kFFBiases ? 2 * kFFHiddenDim : 0),
ffw_output_biases("ffw_out_b", 1, kFFBiases ? kModelDim : 0),
att_weights("att_w", kModelDim, kHeads * kQKVDim)
{}
~CompressedLayer() = default;
post_attention_norm_scale("post_att_ns", 1, config.post_norm == PostNormType::Scale ? config.model_dim : 0),
post_ffw_norm_scale("post_ff_ns", 1, config.post_norm == PostNormType::Scale ? config.model_dim : 0),
ffw_gating_biases("ffw_gat_b", 1, config.ff_biases ? 2 * config.ff_hidden_dim : 0),
ffw_output_biases("ffw_out_b", 1, config.ff_biases ? config.model_dim : 0),
att_weights("att_w", config.model_dim, config.heads * config.qkv_dim),
layer_config(config) {}
~LayerWeightsPtrs() = default;

using Weight = typename TConfig::Weight;
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
// do not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32.
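As a quick illustration of the constructor above, here is a hedged sketch of building one stand-alone layer from a runtime LayerConfig. The dimension values are invented for the example; real values come from ConfigFromModel()/ModelConfig, and any fields not set here keep their defaults.

void StandaloneLayerSketch() {
  gcpp::LayerConfig layer_config;   // Hypothetical dimensions for illustration only.
  layer_config.model_dim = 2048;
  layer_config.heads = 8;
  layer_config.kv_heads = 1;
  layer_config.qkv_dim = 256;
  layer_config.ff_hidden_dim = 16384;

  // Sizes every MatPtrT from the config instead of compile-time TConfig constants.
  gcpp::LayerWeightsPtrs<float> layer(layer_config);
  std::vector<gcpp::MatStorage> layer_storage;
  layer.Allocate(layer_storage);  // Stand-alone layer: each tensor gets its own MatStorage.
}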
@@ -113,25 +135,6 @@ struct CompressedLayer {
hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>;

static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kQKVDim = TConfig::kQKVDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumWSize = (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim;
static constexpr size_t kQKVEinsumBSize = (kHeads + 2 * kKVHeads) * kQKVDim;
// 2x for (gelu gating vector, gated vector)
static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr bool kFFBiases = TConfig::kFFBiases;
static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
static constexpr size_t kAOBiasDim = TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
static constexpr size_t kGriffinDim = TConfig::kGriffinLayers > 0 ? kModelDim : 0;

template <class T>
using ArrayT = MatPtrT<T>;
@@ -195,28 +198,32 @@ struct CompressedLayer {
// Reshaped attention; not loaded from disk via ForEachTensor.
ArrayT<Weight> att_weights;

const LayerConfig& layer_config;

// Initializes att_weights from attn_vec_einsum_w, hence this must be called
// after loading weights via ForEachTensor.
// TODO: update compression/convert_weights to bake this in.
void Reshape(MatStorage& storage) {
void Reshape(MatStorage* storage) {
if (attn_vec_einsum_w.data() == nullptr) return;

constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kQKVDim = TConfig::kQKVDim;
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
const size_t qkv_dim = layer_config.qkv_dim;

// Would have to implement a CompressTraits::Copy for NUQ.
static_assert(!hwy::IsSame<Weight, NuqStream>());
// TODO: implement a CompressTraits::Copy for NUQ.
// static_assert(!hwy::IsSame<Weight, NuqStream>());

// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
storage.Allocate();
att_weights.SetPtr(storage);
for (size_t m = 0; m < kModelDim; ++m) {
Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim;
for (size_t h = 0; h < kHeads; ++h) {
if (storage != nullptr) {
storage->Allocate();
att_weights.SetPtr(*storage);
}
for (size_t m = 0; m < model_dim; ++m) {
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(
attn_vec_einsum_w.data() + h * kModelDim * kQKVDim + m * kQKVDim,
out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
}
}
att_weights.set_scale(attn_vec_einsum_w.scale());
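The loop above re-lays attn_vec_einsum_w out from [heads][model_dim][qkv_dim] into rows of [model_dim][heads * qkv_dim]. A self-contained sketch of the same index arithmetic on plain float buffers, useful for checking the offsets; buffer sizes and ownership are the caller's responsibility and are not taken from this commit:

#include <cstring>

// src is laid out [h][m][q]; dst row m holds each head's qkv_dim values contiguously.
void ReshapeAttWeightsSketch(const float* src, float* dst, size_t heads,
                             size_t model_dim, size_t qkv_dim) {
  for (size_t m = 0; m < model_dim; ++m) {
    float* out_row = dst + m * heads * qkv_dim;
    for (size_t h = 0; h < heads; ++h) {
      std::memcpy(out_row + h * qkv_dim,
                  src + h * model_dim * qkv_dim + m * qkv_dim,
                  qkv_dim * sizeof(float));
    }
  }
}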
@@ -235,11 +242,11 @@ struct CompressedLayer {
}

template <class Func>
static void ForEachTensor(const std::vector<CompressedLayer<TConfig>*>& ptrs,
static void ForEachTensor(const std::vector<LayerWeightsPtrs<Weight>*>& ptrs,
int layer_idx, ForEachType fet, Func func, char sep = ' ', int sep_index = -1) {
MatPtr* tensors[ptrs.size()];
auto type = TConfig::kLayerConfig[layer_idx];
auto type = ptrs[0]->layer_config.type;
if (type == LayerAttentionType::kVit) {
// MHA.
GEMMA_CALL_FUNC(vit.attn_out_w);

@@ -296,17 +303,17 @@ struct CompressedLayer {
GEMMA_CALL_FUNC(pre_attention_norm_scale);
GEMMA_CALL_FUNC(pre_ffw_norm_scale);

if (TConfig::kPostNorm == PostNormType::Scale) {
if (ptrs[0]->layer_config.post_norm == PostNormType::Scale) {
GEMMA_CALL_FUNC(post_attention_norm_scale);
GEMMA_CALL_FUNC(post_ffw_norm_scale);
}

if (TConfig::kFFBiases) {
if (ptrs[0]->layer_config.ff_biases) {
GEMMA_CALL_FUNC(ffw_gating_biases);
GEMMA_CALL_FUNC(ffw_output_biases);
}

if (TConfig::kSoftmaxAttnOutputBiases &&
if (ptrs[0]->layer_config.softmax_attn_output_biases &&
type == LayerAttentionType::kGemma) {
GEMMA_CALL_FUNC(attention_output_biases);
}
@@ -322,47 +329,45 @@ struct CompressedLayer {

// Allocates memory for all the tensors in the layer.
// Note that this is slow and only used for a stand-alone layer.
void Allocate() {
layer_storage.clear();
ForEachTensor({this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
[this](const char* name, hwy::Span<MatPtr*> tensors) {
this->layer_storage.emplace_back(*tensors[0]);
void Allocate(std::vector<MatStorage>& layer_storage) {
ForEachTensor(
{this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
[&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
layer_storage.emplace_back(*tensors[0]);
layer_storage.back().Allocate();
tensors[0]->SetPtr(layer_storage.back());
});
}

// Storage for all the matrices and vectors. Only used for a stand-alone
// layer. For a model, the CompressedWeights::model_storage is used instead.
std::vector<MatStorage> layer_storage;
};

template <class TConfig>
struct CompressedWeights {
explicit CompressedWeights(hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", TConfig::kVocabSize, TConfig::kModelDim),
final_norm_scale("c_final_norm", 1, TConfig::kModelDim),
vit_encoder_norm_bias("enc_norm_bias", 1, TConfig::VitConfig::kModelDim),
vit_encoder_norm_scale("enc_norm_scale", 1, TConfig::VitConfig::kModelDim),
vit_img_embedding_bias("img_emb_bias", 1, TConfig::VitConfig::kModelDim),
template <class Weight>
struct ModelWeightsPtrs {
ModelWeightsPtrs(const ModelConfig& config, hwy::ThreadPool& pool)
: embedder_input_embedding("c_embedding", config.vocab_size, config.model_dim),
final_norm_scale("c_final_norm", 1, config.model_dim),
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3, TConfig::VitConfig::kModelDim),
vit_img_pos_embedding("img_pos_emb", 256, TConfig::VitConfig::kModelDim),
vit_img_head_bias("img_head_bias", 1, TConfig::kModelDim),
vit_img_head_kernel("img_head_kernel", TConfig::VitConfig::kModelDim, TConfig::kModelDim),
scale_names({"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}) {}
vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3, config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.vit_model_dim, config.model_dim),
scale_names(config.scale_names),
weights_config(config) {
c_layers.reserve(config.layer_configs.size());
for (const auto& layer_config : config.layer_configs) {
c_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
}
for (const auto& layer_config : config.vit_layer_configs) {
vit_layers.push_back(LayerWeightsPtrs<Weight>(layer_config));
}
}

~CompressedWeights() = default;

using Weight = typename TConfig::Weight;
using WeightF32OrBF16 = typename CompressedLayer<TConfig>::WeightF32OrBF16;
~ModelWeightsPtrs() = default;
using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;
using WeightF32OrInputT = hwy::If<hwy::IsSame<WeightF32OrBF16, BF16>(), EmbedderInputT, WeightF32OrBF16>;

@@ -380,49 +385,73 @@ struct CompressedWeights {
MatPtrT<float> vit_img_head_bias;
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;

// Storage for all the matrices and vectors.
std::vector<MatStorage> model_storage;
std::unordered_set<std::string> scale_names;

CompressedLayer<TConfig> c_layers[TConfig::kLayers];
CompressedLayer<typename TConfig::VitConfig> vit_layers[TConfig::VitConfig::kLayers];
const ModelConfig& weights_config;

// Called by weights.cc after ForEachTensor.
void Reshape(hwy::ThreadPool& pool) {
std::vector<LayerWeightsPtrs<Weight>> c_layers;
std::vector<LayerWeightsPtrs<Weight>> vit_layers;

// Called by weights.cc after Loading, before att_w has been allocated.
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool, std::vector<MatStorage>& model_storage) {
size_t storage_index = model_storage.size();
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
model_storage.emplace_back(GetLayer(layer)->att_weights);
for (auto& layer : c_layers) {
model_storage.emplace_back(layer.att_weights);
}
pool.Run(0, TConfig::kLayers,
[this, storage_index](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape(model_storage[storage_index + layer]);
pool.Run(0, c_layers.size(),
[this, &model_storage, storage_index](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape(&model_storage[storage_index + layer]);
});
}
// For when the storage has already been allocated.
void CopyWithTranspose(hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape(nullptr);
});
}

void ZeroInit() {
embedder_input_embedding.ZeroInit();
final_norm_scale.ZeroInit();
for (int i = 0; i < TConfig::kLayers; ++i) {
for (size_t i = 0; i < c_layers.size(); ++i) {
c_layers[i].ZeroInit(i);
}
}

const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
return &c_layers[layer];
}
CompressedLayer<TConfig>* GetLayer(size_t layer) { return &c_layers[layer]; }
const CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(size_t layer) const {
LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
const LayerWeightsPtrs<Weight>* GetVitLayer(size_t layer) const {
return &vit_layers[layer];
}
CompressedLayer<typename TConfig::VitConfig>* GetVitLayer(size_t layer) {
LayerWeightsPtrs<Weight>* GetVitLayer(size_t layer) {
return &vit_layers[layer];
}

void Allocate(std::vector<MatStorage>& model_storage, hwy::ThreadPool& pool) {
std::vector<MatPtr*> model_toc;
ForEachTensor(
{this}, ForEachType::kInitNoToc,
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
model_toc.push_back(tensors[0]);
model_storage.emplace_back(*tensors[0]);
});
// Allocate in parallel using the pool.
pool.Run(0, model_toc.size(),
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
// model_storage may have had content before we started.
size_t idx = task + model_storage.size() - model_toc.size();
model_storage[idx].Allocate();
model_toc[task]->SetPtr(model_storage[idx]);
});
}

// Copies the data from other to *this.
void CopyFrom(const CompressedWeights<TConfig>& other) {
ForEachTensor({this, const_cast<CompressedWeights<TConfig>*>(&other)},
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
ForEachType::kIgnoreNulls,
[](const char*, hwy::Span<MatPtr*> tensors) {
hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
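Putting the pieces above together, a hedged sketch of building and allocating the whole pointer table for one weight type. The model enum value is only an example, and in normal use the storage vector lives inside ModelWeightsStorage rather than in the caller:

void AllocateModelSketch(hwy::ThreadPool& pool) {
  gcpp::ModelConfig config = gcpp::ConfigFromModel(gcpp::Model::GEMMA2_2B);
  // One LayerWeightsPtrs per entry in config.layer_configs / vit_layer_configs.
  gcpp::ModelWeightsPtrs<float> weights(config, pool);
  std::vector<gcpp::MatStorage> model_storage;
  weights.Allocate(model_storage, pool);  // Parallel allocation via the pool.
  weights.ZeroInit();
}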
@@ -448,16 +477,14 @@ struct CompressedWeights {
++scale_pos;
}
});
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
HWY_ASSERT(scale_pos == weights_config.num_tensor_scales);
}

template <class Func>
static void ForEachTensor(
const std::vector<CompressedWeights<TConfig>*>& ptrs, ForEachType fet, Func func) {
std::vector<CompressedLayer<TConfig>*> layers(ptrs.size());
std::vector<CompressedLayer<typename TConfig::VitConfig>*> vit_layers(ptrs.size());
static void ForEachTensor(const std::vector<ModelWeightsPtrs<Weight>*>& ptrs,
ForEachType fet, Func func) {
std::vector<LayerWeightsPtrs<Weight>*> layers(ptrs.size());
std::vector<LayerWeightsPtrs<Weight>*> vit_layers(ptrs.size());
MatPtr* tensors[ptrs.size()];
// Variables used by GEMMA_CALL_FUNC.
int layer_idx = -1;

@@ -465,7 +492,7 @@ struct CompressedWeights {
int sep_index = -1;
GEMMA_CALL_FUNC(embedder_input_embedding);
GEMMA_CALL_FUNC(final_norm_scale);
if constexpr (TConfig::VitConfig::kLayers > 0) {
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
// Vit parts.
GEMMA_CALL_FUNC(vit_encoder_norm_bias);
GEMMA_CALL_FUNC(vit_encoder_norm_scale);

@@ -476,90 +503,108 @@ struct CompressedWeights {
GEMMA_CALL_FUNC(vit_img_head_kernel);
}

for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
for (int i = 0; i < ptrs.size(); ++i) {
layers[i] = ptrs[i]->GetLayer(layer_idx);
}
CompressedLayer<TConfig>::ForEachTensor(layers, layer_idx, fet, func);
LayerWeightsPtrs<Weight>::ForEachTensor(layers, layer_idx, fet, func);
}

// Vit layers. Not supported for compress_weights.
if constexpr (TConfig::VitConfig::kLayers > 0) {
for (int layer_idx = 0; layer_idx < TConfig::VitConfig::kLayers;
if (ptrs[0]->weights_config.vit_layer_configs.size() > 0) {
for (int layer_idx = 0; layer_idx < ptrs[0]->vit_layers.size(); ++layer_idx) {
auto type = TConfig::VitConfig::kLayerConfig[layer_idx];
auto type = ptrs[0]->vit_layers[layer_idx].layer_config.type;
HWY_ASSERT(type == LayerAttentionType::kVit);
for (int i = 0; i < ptrs.size(); ++i) {
vit_layers[i] = ptrs[i]->GetVitLayer(layer_idx);
}
CompressedLayer<typename TConfig::VitConfig>::ForEachTensor(vit_layers, layer_idx, fet, func);
LayerWeightsPtrs<Weight>::ForEachTensor(vit_layers, layer_idx, fet, func);
}
}
}
};
#undef GEMMA_CALL_FUNC

// Pair of configs for the compressed and uncompressed weights.
template <class CConfig, class UCConfig>
struct ConfigPair {
using uc = UCConfig;
using c = CConfig;
};

// ----------------------------------------------------------------------------
// Interface

template <typename TConfig>
struct AllocateCompressedWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
using TWeights = CompressedWeights<TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (weights) TWeights(pool);
std::vector<MatPtr*> model_toc;
auto& model_storage = weights->model_storage;
TWeights::ForEachTensor(
{weights}, ForEachType::kInitNoToc,
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
model_toc.push_back(tensors[0]);
model_storage.emplace_back(*tensors[0]);
});
// Allocate in parallel using the pool.
pool.Run(0, model_storage.size(),
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
model_storage[task].Allocate();
model_toc[task]->SetPtr(model_storage[task]);
});
return weights_u8;
class ModelWeightsStorage {
 public:
ModelWeightsStorage() = default;
~ModelWeightsStorage() = default;

BlobError Load(const Path& weights, Model model_type, Type weight_type, hwy::ThreadPool& pool);
void Allocate(Model model_type, Type weight_type, hwy::ThreadPool& pool) {
Allocate(ConfigFromModel(model_type), weight_type, pool);
}
};
void Allocate(const ModelConfig& config, Type weight_type, hwy::ThreadPool& pool);
void RandInit(std::mt19937& gen);
void ZeroInit();
void GetOrApplyScales(std::vector<float>& scales);
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool);
void CopyWithTranspose(hwy::ThreadPool& pool);
void LogWeightStats();
const ModelConfig& Config() const { return config_; }

template <typename TConfig>
struct ZeroInitCompressedWeights {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights.ZeroInit();
template <typename T>
ModelWeightsPtrs<T>* GetWeightsOfType() const {
if constexpr (IsSfpStream<T>()) {
return sfp_weights_.get();
} else if constexpr (IsF32<T>()) {
return float_weights_.get();
} else if constexpr (IsBF16<T>()) {
return bf16_weights_.get();
} else if constexpr (IsNuqStream<T>()) {
return nuq_weights_.get();
} else {
return HWY_ABORT("Unsupported type.");
}
};

template <typename TConfig>
struct ReshapeCompressedWeights {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights.Reshape(pool);
}

template <template <typename T> class FuncT, typename... TArgs>
decltype(auto) CallForModelWeightT(TArgs&&... args) {
if (HWY_LIKELY(sfp_weights_))
return FuncT<SfpStream>()(*sfp_weights_, std::forward<TArgs>(args)...);
if (bf16_weights_)
return FuncT<BF16>()(*bf16_weights_, std::forward<TArgs>(args)...);
if (nuq_weights_)
return FuncT<NuqStream>()(*nuq_weights_, std::forward<TArgs>(args)...);
if (float_weights_)
return FuncT<float>()(*float_weights_, std::forward<TArgs>(args)...);
return HWY_ABORT("No weights loaded.");
}

template <template <typename T> class FuncT, typename... TArgs>
decltype(auto) CallForModelWeight(TArgs&&... args) {
if (HWY_LIKELY(sfp_weights_))
return FuncT<SfpStream>()(*this, std::forward<TArgs>(args)...);
if (bf16_weights_)
return FuncT<BF16>()(*this, std::forward<TArgs>(args)...);
if (nuq_weights_)
return FuncT<NuqStream>()(*this, std::forward<TArgs>(args)...);
if (float_weights_)
return FuncT<float>()(*this, std::forward<TArgs>(args)...);
return HWY_ABORT("No weights loaded.");
}

 private:
void CreateForType(Type weight_type, hwy::ThreadPool& pool);

ModelConfig config_;
// To eliminate type templates, we hold a pointer to one of each weight type
// and dispatch to whichever is non-null.
std::unique_ptr<ModelWeightsPtrs<float>> float_weights_;
std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
// Storage for all the matrices and vectors.
std::vector<MatStorage> model_storage_;
};

// TODO: also add RandInitCompressedWeights

ByteStorageT LoadCompressedWeights(const Path& weights, Model model_type, Type weight_type, hwy::ThreadPool& pool);

void LogWeightStats(Model model, Type weight_type, const ByteStorageT& weights);

} // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
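Before the next file, a short, hypothetical sketch of the single remaining dispatch point in ModelWeightsStorage. The functor name and body are invented; the point is only that FuncT<T> runs with whichever weight type was actually created:

// Hypothetical functor for use with CallForModelWeight(); it receives the
// storage object and can fetch the matching typed pointer table.
template <typename T>
struct CountLayersFunc {
  size_t operator()(gcpp::ModelWeightsStorage& storage) const {
    return storage.GetWeightsOfType<T>()->c_layers.size();
  }
};

size_t NumLayersSketch(gcpp::ModelWeightsStorage& storage) {
  return storage.CallForModelWeight<CountLayersFunc>();  // Dispatches on the non-null weights_.
}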
@@ -115,7 +115,7 @@ void TestMatVecAdd() {
FloatPtr expected_out = SimpleMatVecAdd(*mat, vec, add);
FloatPtr actual_out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && expected_out && actual_out);
MatVecAdd<kOuter, kInner>(*mat, 0, vec.get(), add.get(), actual_out.get(),
MatVecAdd(*mat, 0, kOuter, kInner, vec.get(), add.get(), actual_out.get(),
pool);
AssertClose<kOuter>(actual_out, expected_out);
}

@@ -135,9 +135,8 @@ void TestTwoMatVecAdd() {
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoMatVecAdd<kOuter, kInner>(*mat0, *mat1, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get(), pool);
TwoMatVecAdd(*mat0, *mat1, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get(), pool);
AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1);
}

@@ -156,9 +155,8 @@ void TestTwoOfsMatVecAddLoop() {
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoOfsMatVecAddLoop<kOuter, kInner>(*mat, 0, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get());
TwoOfsMatVecAddLoop(*mat, 0, 0, kOuter, kInner, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get());
AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1);
}
ops/matvec-inl.h (116 changed lines)

@@ -47,10 +47,10 @@ namespace hn = hwy::HWY_NAMESPACE;

// Simple version without tiling nor threading, but two offsets/outputs and
// always with addition.
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT, typename AddT>
template <typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
const size_t mat_ofs1,
const size_t mat_ofs1, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1,

@@ -58,13 +58,13 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
float* HWY_RESTRICT out1) {
PROFILER_ZONE("TwoOfsMatVecAddLoop");

for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
for (size_t idx_row = 0; idx_row < outer; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*inner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*inner;
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
Dot(mat, row_ofs0, vec_aligned, kInner);
Dot(mat, row_ofs0, vec_aligned, inner);
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
Dot(mat, row_ofs1, vec_aligned, kInner);
Dot(mat, row_ofs1, vec_aligned, inner);
}
}

@@ -84,6 +84,14 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
return kRowsPerStrip;
}

HWY_INLINE size_t RowsPerStrip(const size_t outer) {
// Aim for 128 work items to reduce pool overhead. Must be at least one
// vector; prefer a power of two for faster division.
constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
return outer < 128 ? kLanes
: HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(outer / 128));
}

namespace detail {

// For each i = [0, num_rows), compute partial (length `num_cols`) dot product
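The new runtime RowsPerStrip(outer) above targets roughly 128 strips per call; a worked example of the arithmetic, assuming 8 float lanes (e.g. AVX2 — the lane count is an assumption, not fixed by this commit):

// Assuming kLanes = 8:
//   outer = 2048: outer / 128 = 16, FloorLog2(16) = 4, 1 << 4 = 16,
//                 HWY_MAX(8, 16) = 16 rows per strip -> 2048 / 16 = 128 strips.
//   outer = 64 (< 128): returns kLanes = 8 -> 64 / 8 = 8 strips.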
@@ -161,63 +169,63 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,

// Stores dot products of rows with `vec_aligned` + add the values from `add`
// (if kAdd), then stores them to `out`.
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT, typename VecT, typename AddT>
template <bool kAdd, typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
PROFILER_ZONE("MatVecAdd");

const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
const size_t rows_per_strip = RowsPerStrip(outer);
const size_t num_strips = outer / rows_per_strip;

// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("MatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add,
const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0,
rows_per_strip, vec_aligned, add,
out + r0);
});

// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
const size_t r0 = num_strips * rows_per_strip;
if (r0 < outer) {
PROFILER_ZONE("MatVec remainder");
const size_t num_rows = kOuter - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, kInner, r0,
num_rows, vec_aligned, add, out + r0);
const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>(df, mat, mat_ofs, inner, r0, num_rows,
vec_aligned, add, out + r0);
}
}

// With addition
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT, typename AddT>
template <typename ArrayT, typename VecT, typename AddT>
HWY_INLINE void MatVecAdd(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
const AddT* HWY_RESTRICT const add,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
return MatVecT</*kAdd=*/true, kOuter, kInner>(mat, mat_ofs, vec_aligned, add,
return MatVecT</*kAdd=*/true>(mat, mat_ofs, outer, inner, vec_aligned, add,
out, pool);
}

// Without addition
template <size_t kOuter, size_t kInner, typename ArrayT, typename VecT>
template <typename ArrayT, typename VecT>
HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT const vec_aligned,
float* HWY_RESTRICT out, hwy::ThreadPool& pool) {
MatVecT</*kAdd=*/false, kOuter, kInner>(mat, mat_ofs, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr),
out, pool);
MatVecT</*kAdd=*/false>(mat, mat_ofs, outer, inner, vec_aligned,
/*add=*/static_cast<VecT*>(nullptr), out, pool);
}

// Two matrices, same vector
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
template <bool kAdd, typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs,
const size_t mat_ofs, size_t outer, size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1,
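With the dimensions now runtime values, callers simply pass them along; a hedged caller sketch (the wrapper name and the shape comments are assumptions, not code from this commit):

// Assumed shapes: mat is outer x inner, vec_aligned has inner elements,
// add and out each have outer elements.
template <typename ArrayT>
void DenseLayerSketch(const ArrayT& mat, size_t outer, size_t inner,
                      const float* HWY_RESTRICT vec_aligned,
                      const float* HWY_RESTRICT add, float* HWY_RESTRICT out,
                      hwy::ThreadPool& pool) {
  MatVecAdd(mat, /*mat_ofs=*/0, outer, inner, vec_aligned, add, out, pool);
}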
@@ -226,56 +234,56 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT1& mat0, const ArrayT2& mat1,
PROFILER_ZONE("TwoMatVecAdd");

const hn::ScalableTag<float> df;
constexpr size_t kRowsPerStrip = RowsPerStrip<kOuter>();
constexpr size_t kNumStrips = kOuter / kRowsPerStrip;
const size_t rows_per_strip = RowsPerStrip(outer);
const size_t num_strips = outer / rows_per_strip;

// For each entire strip.
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
pool.Run(0, num_strips, [&](const uint64_t strip, size_t thread) HWY_ATTR {
PROFILER_ZONE("TwoMatVec.lambda");
const size_t r0 = strip * kRowsPerStrip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add0,
const size_t r0 = strip * rows_per_strip;
detail::FullDotProductsForStrip<kAdd>(df, mat0, mat_ofs, inner, r0,
rows_per_strip, vec_aligned, add0,
out0 + r0);
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, kInner, r0,
kRowsPerStrip, vec_aligned, add1,
detail::FullDotProductsForStrip<kAdd>(df, mat1, mat_ofs, inner, r0,
rows_per_strip, vec_aligned, add1,
out1 + r0);
});

// Remaining rows
const size_t r0 = kNumStrips * kRowsPerStrip;
if (r0 < kOuter) {
const size_t r0 = num_strips * rows_per_strip;
if (r0 < outer) {
PROFILER_ZONE("TwoMatVec remainder");
const size_t num_rows = kOuter - r0;
const size_t num_rows = outer - r0;
detail::FullDotProductsForStrip<kAdd>(
df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0);
df, mat0, mat_ofs, inner, r0, num_rows, vec_aligned, add0, out0 + r0);
detail::FullDotProductsForStrip<kAdd>(
df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0);
df, mat1, mat_ofs, inner, r0, num_rows, vec_aligned, add1, out1 + r0);
}
}

// With addition
template <size_t kOuter, size_t kInner, typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
template <typename ArrayT1, typename ArrayT2, typename VecT, typename AddT>
HWY_NOINLINE void TwoMatVecAdd(
const ArrayT1& mat0, const ArrayT2& mat1, const size_t mat_ofs,
const size_t outer, const size_t inner,
const VecT* HWY_RESTRICT vec_aligned, const AddT* HWY_RESTRICT add0,
const AddT* HWY_RESTRICT add1, float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1, hwy::ThreadPool& pool) {
return TwoMatVecT</*kAdd=*/true, kOuter, kInner>(
mat0, mat1, mat_ofs, vec_aligned, add0, add1, out0, out1, pool);
return TwoMatVecT</*kAdd=*/true>(mat0, mat1, mat_ofs, outer, inner,
vec_aligned, add0, add1, out0, out1, pool);
}

// Without addition
template <size_t kOuter, size_t kInner, typename ArrayT1, typename ArrayT2, typename VecT>
template <typename ArrayT1, typename ArrayT2, typename VecT>
HWY_NOINLINE void TwoMatVec(const ArrayT1& mat0, const ArrayT2& mat1,
const size_t mat_ofs,
const size_t mat_ofs, const size_t outer,
const size_t inner,
const VecT* HWY_RESTRICT vec_aligned,
float* HWY_RESTRICT out0, float* HWY_RESTRICT out1,
hwy::ThreadPool& pool) {
TwoMatVecT</*kAdd=*/false, kOuter, kInner, ArrayT1, ArrayT2, VecT, VecT>(
mat0, mat1, mat_ofs, vec_aligned, /*add0=*/nullptr, /*add1=*/nullptr,
out0, out1, pool);
TwoMatVecT</*kAdd=*/false, ArrayT1, ArrayT2, VecT, VecT>(
mat0, mat1, mat_ofs, outer, inner, vec_aligned, /*add0=*/nullptr,
/*add1=*/nullptr, out0, out1, pool);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
@@ -21,11 +21,11 @@
#include <stddef.h>
#include <stdio.h>

#include <array>
#include <cmath>
#include <limits>
#include <random>
#include <type_traits>  // std::enable_if_t
#include <vector>

#include "compression/compress.h"
#include "util/basics.h"  // TokenAndProb

@@ -673,9 +673,8 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
return max_index;
}

template <size_t k>
HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
create_distribution(std::array<float, k>& top_k, float temperature) {
HWY_INLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> create_distribution(
std::vector<float>& top_k, float temperature) {
HWY_ASSERT(temperature >= 0.0f);
if (temperature == 0.0f) {
// Temperature == 0 is a special case which always returns the argmax (0).

@@ -696,16 +695,16 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
}

template <size_t k, typename TAcceptToken>
template <typename TAcceptToken>
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size,
const float* HWY_RESTRICT probabilities, size_t k, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, "");
HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size);
// TODO: Optimize, potentially using new VQSort PartialSort.
std::array<float, k> top_k{};  // sorted from highest [0], to lowest [k-1]
top_k.fill(-std::numeric_limits<float>::infinity());
std::array<int, k> indices{};
// Sorted from highest [0], to lowest [k-1]
std::vector<float> top_k(k, -std::numeric_limits<float>::infinity());
std::vector<int> indices(k);
size_t num_accepted = 0;
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1]) continue;

@@ -727,7 +726,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
}
}
HWY_ASSERT(k <= num_accepted);
return indices[create_distribution<k>(top_k, temperature)(gen)];
return indices[create_distribution(top_k, temperature)(gen)];
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
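Since k is now an ordinary argument, a hedged caller sketch of the new SampleTopK (the probability vector and the choice of k are placeholders; accept_all admits every token):

int SampleSketch(const std::vector<float>& probs, std::mt19937& gen) {
  // probs is assumed to hold softmax-normalized probabilities over the vocabulary.
  auto accept_all = [](int /*token*/, float /*prob*/) { return true; };
  return SampleTopK(probs.data(), /*k=*/3, probs.size(), gen,
                    /*temperature=*/1.0f, accept_all);
}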
@@ -387,8 +387,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}

void TestRopeAndMulBy() {
using Config = ConfigGemma2_9B<float>;
int dim_qkv = Config::kQKVDim;
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(1, dim_qkv);

std::mt19937 gen;

@@ -400,15 +400,15 @@ void TestRopeAndMulBy() {
x.All()[i] = random_float();
}

const float qmul = ChooseQueryScale<Config>();
const float qmul = ChooseQueryScale(config);
const float kmul = 1.0;

std::vector<float> qexpected(dim_qkv);
std::vector<float> qactual(dim_qkv);
std::vector<float> kexpected(dim_qkv);
std::vector<float> kactual(dim_qkv);
RowVectorBatch<float> inv_timescale =
gcpp::Activations::CreateInvTimescale<Config>();
RowVectorBatch<float> inv_timescale = gcpp::Activations::CreateInvTimescale(
config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk);
// Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings

@@ -571,20 +571,20 @@ void TestSampleTopK() {
float temperature = 1.0f;
// SampleTopK<1> should return the argmax.
std::function<bool(int, float)> accept_token;
int sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
accept_token);
int sample =
SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
EXPECT_EQ(sample, 51);  // Last is largest.
// Only accept even tokens, expect the last (largest) even index.
accept_token = [](int i, float) { return i % 2 == 0; };
sample = SampleTopK<1>(logits.data(), kSize, gen, temperature,
accept_token);
sample =
SampleTopK(logits.data(), /*k=*/1, kSize, gen, temperature, accept_token);
EXPECT_EQ(sample, 50);  // Last even index.
// Reset the logits to a positive, increasing sequence and take Softmax.
std::iota(logits.begin(), logits.end(), 1.0f);
Softmax(logits.data(), kSize);
// Sample from the top 3, expect one of the top 3 even indices.
for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token);
EXPECT_TRUE(sample == 50 || sample == 48 || sample == 46);
}

@@ -592,7 +592,7 @@ void TestSampleTopK() {
// even for k=3.
temperature = 0.0f;
for (int i = 0; i < 100; ++i) {
sample = SampleTopK<3>(logits.data(), kSize, gen, temperature,
sample = SampleTopK(logits.data(), /*k=*/3, kSize, gen, temperature,
accept_token);
EXPECT_EQ(sample, 50);
}

@@ -189,6 +189,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
const ModelInfo& Info() const { return info_; }

 private:
// TODO(rays): remove this. Eventually ModelConfig will be loaded from the
// weights file, so we can remove the need for this struct entirely.
ModelInfo info_;
};