mirror of https://github.com/google/gemma.cpp.git
parent 9d83ff202e
commit 4ab601da10
@@ -189,6 +189,7 @@ constexpr bool IsNuqStream() {
enum class PromptWrapping {
  GEMMA_IT,
  GEMMA_PT,
  GEMMA_VLM,
  PALIGEMMA,
  kSentinel  // must be last
};
@@ -61,6 +61,7 @@ struct Activations {
  // Rope
  RowVectorBatch<float> inv_timescale;
  RowVectorBatch<float> inv_timescale_global;

  // Dynamic because no default ctor and only initialized in `Allocate`.
  MatMulEnv* env;

@@ -108,6 +109,8 @@ struct Activations {
    inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
                                       post_qk == PostQKType::HalfRope);
    inv_timescale_global =
        CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);

    this->env = env;
  }
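The second table above is built with base 1000000.0 instead of the usual 10000.0. A minimal sketch of such an inverse-timescale table, assuming CreateInvTimescale follows the standard RoPE formula base^(-2i/d) over the rotated part of the head dimension (its exact signature and defaults are an assumption here, not shown in this diff):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: one inverse-timescale entry per rotated pair of dimensions.
// A larger base (e.g. 1000000.0 for the global-attention layers) stretches
// the rotation wavelengths, which suits long-range positions.
std::vector<float> InvTimescaleSketch(size_t qkv_dim, bool half_rope,
                                      double base = 10000.0) {
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  std::vector<float> inv_timescale(rope_dim / 2);
  for (size_t i = 0; i < rope_dim / 2; ++i) {
    inv_timescale[i] = static_cast<float>(
        std::pow(base, -2.0 * static_cast<double>(i) / rope_dim));
  }
  return inv_timescale;
}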
@@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = {
    "paligemma2-3b-448",   // PaliGemma2 3B 448
    "paligemma2-10b-224",  // PaliGemma2 10B 224
    "paligemma2-10b-448",  // PaliGemma2 10B 448
    "gemma3-4b",           // Gemma3 4B
    "gemma3-1b",           // Gemma3 1B
    "gemma3-12b",          // Gemma3 12B
    "gemma3-27b",          // Gemma3 27B
};
constexpr Model kModelTypes[] = {
    Model::GEMMA_2B, Model::GEMMA_2B,  // Gemma 2B

@@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = {
    Model::PALIGEMMA2_3B_448,   // PaliGemma2 3B 448
    Model::PALIGEMMA2_10B_224,  // PaliGemma2 10B 224
    Model::PALIGEMMA2_10B_448,  // PaliGemma2 10B 448
    Model::GEMMA3_4B,           // Gemma3 4B
    Model::GEMMA3_1B,           // Gemma3 1B
    Model::GEMMA3_12B,          // Gemma3 12B
    Model::GEMMA3_27B,          // Gemma3 27B
};
constexpr PromptWrapping kPromptWrapping[] = {
    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,  // Gemma 2B

@@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = {
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PaliGemma 224/448
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
    PromptWrapping::GEMMA_VLM,  // Gemma3 4B
    PromptWrapping::GEMMA_PT,   // Gemma3 1B
    PromptWrapping::GEMMA_VLM,  // Gemma3 12B
    PromptWrapping::GEMMA_VLM,  // Gemma3 27B
};

constexpr size_t kNumModelFlags = std::size(kModelFlags);
gemma/configs.cc
@@ -328,6 +328,186 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
  return config;
}

static ModelConfig ConfigBaseGemmaV3() {
  ModelConfig config = ConfigNoSSM();
  config.att_cap = 0.0f;
  config.final_cap = 0.0f;
  return config;
}

// 1B does not include a vision encoder.
static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
  LayerConfig config;
  config.model_dim = model_dim;
  config.ff_hidden_dim = 6912;
  config.heads = 4;
  config.kv_heads = 1;
  config.qkv_dim = 256;
  config.optimized_gating = true;
  config.post_norm = PostNormType::Scale;
  config.use_qk_norm = true;
  return config;
}

static ModelConfig ConfigGemma3_1B() {
  ModelConfig config = ConfigBaseGemmaV3();
  config.model_name = "Gemma3_1B";
  config.model = Model::GEMMA3_1B;
  config.model_dim = 1152;
  config.vocab_size = 262144;  // new vocab size / tokenizer
  config.seq_len = 32 * 1024;
  LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
  config.layer_configs = {26, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  // interleaved local / global attention
  config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
      {512, 512, 512, 512, 512, config.seq_len});
  return config;
}

static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
  LayerConfig config;
  config.model_dim = model_dim;
  config.ff_hidden_dim = 8 * 2560 / 2;  // = 10240
  config.heads = 8;
  config.kv_heads = 4;
  config.qkv_dim = 256;
  config.optimized_gating = true;
  config.post_norm = PostNormType::Scale;
  config.use_qk_norm = true;
  return config;
}

// Until we have the SigLIP checkpoints included, we use the LM config directly.
static ModelConfig ConfigGemma3_4B_LM() {
  ModelConfig config = ConfigBaseGemmaV3();
  config.model_name = "Gemma3_4B";
  config.model = Model::GEMMA3_4B;
  config.model_dim = 2560;
  config.vocab_size = 262144;  // new vocab size / tokenizer
  config.seq_len = 32 * 1024;
  LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
  config.layer_configs = {34, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  // interleaved local / global attention
  config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
      {1024, 1024, 1024, 1024, 1024, config.seq_len});
  return config;
}

static ModelConfig ConfigGemma3_4B() {
  ModelConfig config = ConfigGemma3_4B_LM();
  config.model_name = "Gemma3_4B";
  config.model = Model::GEMMA3_4B;
  AddVitConfig(config, /*image_size=*/896);
  config.vocab_size = 262144;
  config.vit_config.pool_dim = 4;
  const size_t num_patches =
      config.vit_config.image_size / config.vit_config.patch_width;
  config.vit_config.seq_len = (num_patches * num_patches);
  // The above resets optimized gating to false; for Gemma 3 it should be true.
  for (auto& layer_config : config.layer_configs) {
    layer_config.optimized_gating = true;
  }
  return config;
}

static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
  LayerConfig config;
  config.model_dim = model_dim;
  config.ff_hidden_dim = 15360;
  config.heads = 16;
  config.kv_heads = 8;
  config.qkv_dim = 256;
  config.optimized_gating = true;
  config.post_norm = PostNormType::Scale;
  config.use_qk_norm = true;
  return config;
}

static ModelConfig ConfigGemma3_12B_LM() {
  ModelConfig config = ConfigBaseGemmaV3();
  config.model_name = "Gemma3_12B";
  config.model = Model::GEMMA3_12B;
  config.model_dim = 3840;
  config.vocab_size = 262144;  // new vocab size / tokenizer
  config.seq_len = 32 * 1024;
  LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
  config.layer_configs = {48, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  // interleaved local / global attention
  config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
      {1024, 1024, 1024, 1024, 1024, config.seq_len});
  return config;
}

static ModelConfig ConfigGemma3_12B() {
  ModelConfig config = ConfigGemma3_12B_LM();
  config.model_name = "Gemma3_12B";
  config.model = Model::GEMMA3_12B;
  AddVitConfig(config, /*image_size=*/896);
  config.vocab_size = 262144;
  config.vit_config.pool_dim = 4;
  const size_t num_patches =
      config.vit_config.image_size / config.vit_config.patch_width;
  config.vit_config.seq_len = (num_patches * num_patches);
  // The above resets optimized gating to false; for Gemma 3 it should be true.
  for (auto& layer_config : config.layer_configs) {
    layer_config.optimized_gating = true;
  }
  return config;
}

static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
  LayerConfig config;
  config.model_dim = model_dim;
  config.ff_hidden_dim = 21504;
  config.heads = 32;
  config.kv_heads = 16;
  config.qkv_dim = 128;
  config.optimized_gating = true;
  config.post_norm = PostNormType::Scale;
  config.use_qk_norm = true;
  return config;
}

static ModelConfig ConfigGemma3_27B_LM() {
  ModelConfig config = ConfigBaseGemmaV3();
  config.model_name = "Gemma3_27B";
  config.model = Model::GEMMA3_27B;
  config.model_dim = 5376;
  config.vocab_size = 262144;  // new vocab size / tokenizer
  config.seq_len = 32 * 1024;
  LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
  config.layer_configs = {62, layer_config};
  config.num_tensor_scales = 4 * config.layer_configs.size();
  config.query_scale = QueryScaleType::SqrtKeySize;
  // interleaved local / global attention
  config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
      {1024, 1024, 1024, 1024, 1024, config.seq_len});
  return config;
}

static ModelConfig ConfigGemma3_27B() {
  ModelConfig config = ConfigGemma3_27B_LM();
  config.model_name = "Gemma3_27B";
  config.model = Model::GEMMA3_27B;
  AddVitConfig(config, /*image_size=*/896);
  config.vocab_size = 262144;
  config.vit_config.pool_dim = 4;
  const size_t num_patches =
      config.vit_config.image_size / config.vit_config.patch_width;
  config.vit_config.seq_len = (num_patches * num_patches);
  // The above resets optimized gating to false; for Gemma 3 it should be true.
  for (auto& layer_config : config.layer_configs) {
    layer_config.optimized_gating = true;
  }
  return config;
}

ModelConfig ConfigFromModel(Model model) {
  switch (model) {
    case Model::GEMMA_2B:
@@ -356,6 +536,14 @@ ModelConfig ConfigFromModel(Model model) {
      return ConfigPaliGemma2_10B_224();
    case Model::PALIGEMMA2_10B_448:
      return ConfigPaliGemma2_10B_448();
    case Model::GEMMA3_4B:
      return ConfigGemma3_4B();
    case Model::GEMMA3_1B:
      return ConfigGemma3_1B();
    case Model::GEMMA3_12B:
      return ConfigGemma3_12B();
    case Model::GEMMA3_27B:
      return ConfigGemma3_27B();
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
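The "interleaved local / global attention" comments above describe a repeating block of five local sliding-window layers followed by one global layer whose window equals seq_len. A minimal sketch of how such a repeated pattern expands into per-layer window sizes; the helper below is hypothetical, and RepeatedAttentionWindowSizes<kLayers, kPattern> in configs.h is only assumed to cycle through its argument in the same way:

#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: repeats a pattern of kPattern window sizes across
// kLayers layers, e.g. {512 x5, seq_len} for Gemma3 1B's 26 layers.
template <size_t kLayers, size_t kPattern>
std::array<uint32_t, kLayers> ExpandWindowPattern(
    const std::array<uint32_t, kPattern>& pattern) {
  std::array<uint32_t, kLayers> windows = {};
  for (size_t layer = 0; layer < kLayers; ++layer) {
    windows[layer] = pattern[layer % kPattern];
  }
  return windows;
}

// Example: for Gemma3 1B, layers 5, 11, 17 and 23 become global
// (window == 32 * 1024); all other layers use a 512-token sliding window.
// const auto windows = ExpandWindowPattern<26, 6>(
//     {512, 512, 512, 512, 512, 32 * 1024});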
@@ -151,6 +151,10 @@ enum class Model {
  PALIGEMMA2_3B_448,
  PALIGEMMA2_10B_224,
  PALIGEMMA2_10B_448,
  GEMMA3_4B,
  GEMMA3_1B,
  GEMMA3_12B,
  GEMMA3_27B,
};

// Allows the Model enum to be iterated over.

@@ -159,7 +163,8 @@ static constexpr Model kAllModels[] = {
    Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
    Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
    Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
    Model::PALIGEMMA2_10B_448,
    Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
    Model::GEMMA3_12B, Model::GEMMA3_27B,
};

inline bool EnumValid(Model model) {

@@ -202,6 +207,7 @@ struct LayerConfig : public IFields {
    visitor(type);
    visitor(activation);
    visitor(post_qk);
    visitor(use_qk_norm);
  }

  uint32_t model_dim = 0;

@@ -218,6 +224,7 @@ struct LayerConfig : public IFields {
  LayerAttentionType type = LayerAttentionType::kGemma;
  ActivationType activation = ActivationType::Gelu;
  PostQKType post_qk = PostQKType::Rope;
  bool use_qk_norm = false;
};

// Dimensions related to image processing.
@@ -219,6 +219,17 @@ class GemmaAttention {
    // qk is either q or k, so qkv_dim is the length we operate on.
    const size_t qkv_dim = layer_config_.qkv_dim;
    const float* inv_timescale = activations_.inv_timescale.Const();
    bool is_global_layer =
        activations_.weights_config.attention_window_sizes[layer] ==
        activations_.seq_len;
    // TODO: add a config flag instead of hardcoding the model.
    if (is_global_layer &&
        (activations_.weights_config.model == Model::GEMMA3_4B ||
         activations_.weights_config.model == Model::GEMMA3_12B ||
         activations_.weights_config.model == Model::GEMMA3_27B ||
         activations_.weights_config.model == Model::GEMMA3_1B)) {
      inv_timescale = activations_.inv_timescale_global.Const();
    }
    // PostQKType::Rope
    (void)layer;
    if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {

@@ -324,6 +335,10 @@ class GemmaAttention {
          }

          // Apply further processing to K.
          if (layer_weights_.key_norm_scale.data()) {
            RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
                           qkv_dim);
          }
          PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
        });
  }

@@ -411,6 +426,10 @@ class GemmaAttention {

        // Apply rope and scaling to Q.
        const size_t pos = queries_pos_[query_idx] + batch_idx;
        if (layer_weights_.query_norm_scale.data()) {
          RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
                         qkv_dim);
        }
        PositionalEncodingQK(q, pos, layer_, query_scale);

        const size_t start_pos = StartPos(pos, layer_);
@@ -590,6 +609,7 @@ class VitAttention {
           RowPtrFromBatch(qkv));
  }

  // TODO(philculliton): transition fully to MatMul.
  HWY_NOINLINE void DotSoftmaxWeightedSum() {
    const size_t qkv_dim = layer_config_.qkv_dim;
    const size_t heads = layer_config_.heads;

@@ -598,35 +618,56 @@ class VitAttention {
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");

    // Compute Q.K, softmax, and weighted V.
    pool_.Run(0, layer_config_.heads * num_tokens_,
              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      const size_t head = task % layer_config_.heads;
      const size_t token = task / layer_config_.heads;
      // Compute Q.K scores, which are "logits" stored in head_att.
    // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
    RowVectorBatch<float> Q =
        AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
    RowVectorBatch<float> K =
        AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
    RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));

    // Initialize att_out to zero prior to head loop.
    hwy::ZeroBytes(activations_.att_out.All(),
                   num_tokens_ * heads * qkv_dim * sizeof(float));

    for (size_t head = 0; head < heads; ++head) {
      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t token = task;
        float* HWY_RESTRICT q =
            activations_.q.Batch(token) + head * 3 * qkv_dim;
        // TODO: shift to MatMul with A.scale once MatMul is confirmed working
        MulByConst(query_scale, q, qkv_dim);
        float* HWY_RESTRICT head_att =
            activations_.att.Batch(token) + head * activations_.seq_len;
        for (size_t i = 0; i < seq_len; ++i) {
        hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
      });

      pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        const size_t seq_idx = task;
        float* HWY_RESTRICT k =
            activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
          head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
        }
        // SoftMax yields "probabilities" in head_att.
        Softmax(head_att, seq_len);
        // Compute weighted sum of v into att_out.
            activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
        hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
      });

      // this produces C, a (num_tokens_, seq_len) matrix of dot products
      MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
             ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
             RowPtrFromBatch(C));

      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        float* HWY_RESTRICT c = C.Batch(task);
        Softmax(c, C.Cols());
      });

      pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
        size_t token = task;
        float* HWY_RESTRICT att_out =
            activations_.att_out.Batch(token) + head * qkv_dim;
        hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
        for (size_t i = 0; i < seq_len; ++i) {
          float* HWY_RESTRICT v = activations_.q.Batch(i) +
                                  head * 3 * qkv_dim + 2 * qkv_dim;
          MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
          float* HWY_RESTRICT v =
              activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
          MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
        }
      });
    }
  }

  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
  // head_dim (`qkv_dim`) into output (`att_sums`).
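For reference, a plain scalar version of what the rewritten per-head path computes: C = Q.K^T over all token pairs, a row-wise softmax of C, then a weighted sum of V rows into att_out. This sketch uses std::vector in place of RowVectorBatch/MatMul and is illustrative only, not the Highway implementation above. att_out must already have num_tokens rows.

#include <cmath>
#include <cstddef>
#include <vector>

void AttentionReference(const std::vector<std::vector<float>>& Q,  // pre-scaled
                        const std::vector<std::vector<float>>& K,
                        const std::vector<std::vector<float>>& V,
                        std::vector<std::vector<float>>& att_out) {
  const size_t num_tokens = Q.size(), seq_len = K.size(), qkv_dim = Q[0].size();
  for (size_t t = 0; t < num_tokens; ++t) {
    std::vector<float> c(seq_len);
    float max = -1e30f, sum = 0.0f;
    for (size_t i = 0; i < seq_len; ++i) {  // C = Q.K^T
      float dot = 0.0f;
      for (size_t d = 0; d < qkv_dim; ++d) dot += Q[t][d] * K[i][d];
      c[i] = dot;
      if (dot > max) max = dot;
    }
    for (size_t i = 0; i < seq_len; ++i) {  // softmax of row t
      c[i] = std::exp(c[i] - max);
      sum += c[i];
    }
    att_out[t].assign(qkv_dim, 0.0f);
    for (size_t i = 0; i < seq_len; ++i) {  // weighted sum of V rows
      const float w = c[i] / sum;
      for (size_t d = 0; d < qkv_dim; ++d) att_out[t][d] += w * V[i][d];
    }
  }
}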
@@ -784,14 +825,30 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
//
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
// spec) until we run out of image tokens. This allows for a multi-image prompt
// if -2 locations with appropriate begin/end image tokens are created by the
// calling application.
template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
                               size_t pos_in_prompt,
                               const ModelWeightsPtrs<T>& weights,
                               RowVectorBatch<float>& x,
                               const ImageTokens* image_tokens) {
                               const ImageTokens* image_tokens,
                               size_t& image_token_position) {
  // Image tokens just need to be copied.
  if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
      image_tokens != nullptr && token == -2 &&
      image_token_position < image_tokens->BatchSize()) {
    hwy::CopyBytes(image_tokens->Batch(image_token_position),
                   x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
    image_token_position++;
    return;
  }

  if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
      image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
    hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
                   x.Cols() * sizeof(x.Const()[0]));
    return;

@@ -816,6 +873,21 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
  }
}

// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
// This version of the function doesn't track internal image token position.
template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
                             size_t pos_in_prompt,
                             const ModelWeightsPtrs<T>& weights,
                             RowVectorBatch<float>& x,
                             const ImageTokens* image_tokens) {
  size_t image_token_position = 0;
  EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
                  image_tokens, image_token_position);
}

template <typename Weights, typename T>
HWY_NOINLINE void ResidualConnection(
    size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
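A stripped-down sketch of the GEMMA_VLM placeholder handling above, using plain std::vector rows in place of RowVectorBatch and ImageTokens; each -2 token consumes the next pooled image row, and anything else falls through to the normal embedding lookup (omitted here):

#include <cstddef>
#include <vector>

using Row = std::vector<float>;

// Returns true if the token was handled as an image placeholder.
bool EmbedPlaceholderSketch(int token, const std::vector<Row>& image_rows,
                            size_t& image_token_position, Row& x_row) {
  if (token == -2 && image_token_position < image_rows.size()) {
    x_row = image_rows[image_token_position];  // copy one soft token
    ++image_token_position;
    return true;
  }
  return false;  // caller embeds it as a regular text token
}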
@@ -990,12 +1062,13 @@ HWY_NOINLINE void Prefill(
          HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);

      // Fill activations.x (much faster than TransformerLayer).
      size_t image_token_position = 0;
      for (size_t ti = 0; ti < tbatch_size; ++ti) {
        const size_t pos = queries_pos[qi] + ti;
        const size_t pos_in_prompt = tbatch_start + ti;
        const int token = queries_prompt[qi][pos_in_prompt];
        EmbedToken(token, ti, pos, pos_in_prompt, weights, activations.x,
                   runtime_config.image_tokens);
        EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
                     runtime_config.image_tokens, image_token_position);
      }

      // Transformer with one batch of tokens from a single query.

@@ -1104,8 +1177,14 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                 weights.vit_encoder_norm_bias.data_scale1(),
                 activations.x.All(), vit_model_dim);

  activations.x = AvgPool4x4(activations.x);

  // Apply soft embedding norm before input projection.
  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
                 vit_model_dim);

  // Apply head embedding into image_tokens of size of the LLM kModelDim.
  MatMul(ConstMatFromBatch(num_tokens, activations.x),
  MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
         ConstMatFromWeights(weights.vit_img_head_kernel),
         weights.vit_img_head_bias.data_scale1(), *activations.env,
         RowPtrFromBatch(image_tokens));

@@ -1133,10 +1212,11 @@ HWY_NOINLINE void Transformer(
    }
  }

  size_t image_token_position = 0;
  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
    EmbedToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
    EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
                 /*pos_in_prompt=*/0, weights, activations.x,
                 /*image_tokens=*/nullptr);
                 /*image_tokens=*/nullptr, image_token_position);
  }

  for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {

@@ -1440,7 +1520,8 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
    }
    RuntimeConfig prefill_runtime_config = runtime_config;
    ModelConfig vit_config = GetVitConfig(model.Config());
    prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
    prefill_runtime_config.prefill_tbatch_size =
        vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
    Activations prefill_activations(vit_config);
    prefill_activations.Allocate(vit_config.seq_len, env);
    // Weights are for the full PaliGemma model, not just the ViT part.
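A worked example of the sizing used here and in gemma/run.cc below, assuming AddVitConfig(/*image_size=*/896) keeps the SigLIP patch width of 14 (the patch width itself is not shown in this diff):

#include <cstddef>

constexpr size_t kImageSize = 896;
constexpr size_t kPatchWidth = 14;  // assumed
constexpr size_t kNumPatches = kImageSize / kPatchWidth;             // 64
constexpr size_t kVitSeqLen = kNumPatches * kNumPatches;             // 4096
constexpr size_t kPoolDim = 4;
constexpr size_t kImageTokens = kVitSeqLen / (kPoolDim * kPoolDim);  // 256
static_assert(kImageTokens == 256, "Gemma3 VLM yields 256 soft tokens");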
gemma/run.cc
@@ -94,10 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
  Image image;
  ImageTokens image_tokens;
  if (have_image) {
    image_tokens =
        ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
    size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
    image_tokens = ImageTokens(Extents2D(
        model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
        model.GetModelConfig().model_dim));
    HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
    HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
               model.Info().wrapping == PromptWrapping::GEMMA_VLM);
    HWY_ASSERT(image.ReadPPM(args.image_file.path));
    const size_t image_size = model.GetModelConfig().vit_config.image_size;
    image.Resize(image_size, image_size);

@@ -190,7 +192,15 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    size_t prefix_end = 0;
    if (have_image) {
      runtime_config.image_tokens = &image_tokens;
      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
        prompt =
            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
      }
      prompt_size = prompt.size();
      // The end of the prefix for prefix-LM style attention in Paligemma.
      // See Figure 2 of https://arxiv.org/abs/2407.07726.

@@ -209,9 +219,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,

    // Prepare for the next turn.
    if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
      abs_pos = 0;  // Start a new turn at position 0.
      InitGenerator(args, gen);
    } else {
      // The last token was either EOS, then it should be ignored because it is
      // never part of the dialog, see Table 5 in the Gemma-2 paper:
      // https://arxiv.org/pdf/2408.00118
@@ -64,14 +64,14 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
      },
      TensorInfo{
          .name = "img_head_bias",
          .source_names = {"img/head/bias"},
          .source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
          .axes = {0},
          .shape = {config.model_dim},
          .min_size = Type::kF32,
      },
      TensorInfo{
          .name = "img_head_kernel",
          .source_names = {"img/head/kernel"},
          .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
          .axes = {1, 0},
          .shape = {config.model_dim, config.vit_config.model_dim},
          .min_size = Type::kBF16,

@@ -84,12 +84,21 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
                    config.vit_config.model_dim},
          .min_size = Type::kF32,
      },
      // RMS norm applied to soft tokens prior to pos embedding.
      TensorInfo{
          .name = "mm_embed_norm",
          .source_names = {"embedder/mm_soft_embedding_norm/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
  };
}

// Returns the tensors for the given image layer config.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
                                          const LayerConfig& layer_config) {
                                          const LayerConfig& layer_config,
                                          const int img_layer_idx) {
  return {
      // Vit layers.
      TensorInfo{

@@ -207,28 +216,40 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
      },
      TensorInfo{
          .name = "ln_0_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_0_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
          .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_0/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_bias",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/bias"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "ln_1_scale",
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
          .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
                           "img/Transformer/encoderblock_" +
                               std::to_string(img_layer_idx) +
                               "/LayerNorm_1/scale"},
          .axes = {0},
          .shape = {config.vit_config.model_dim},
          .min_size = Type::kBF16,

@@ -241,6 +262,20 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
                                        const LayerConfig& layer_config,
                                        bool reshape_att) {
  std::vector<TensorInfo> tensors = {
      TensorInfo{
          .name = "key_norm",
          .source_names = {"attn/_key_norm/scale"},
          .axes = {0},
          .shape = {layer_config.qkv_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "query_norm",
          .source_names = {"attn/_query_norm/scale"},
          .axes = {0},
          .shape = {layer_config.qkv_dim},
          .min_size = Type::kBF16,
      },
      TensorInfo{
          .name = "qkv1_w",
          .source_names = {"attn/q_einsum/w"},

@@ -529,7 +564,7 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
  } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
             img_layer_idx < config.vit_config.layer_configs.size()) {
    const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
    tensors_ = ImageLayerTensors(config, layer_config);
    tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
  } else if (0 <= llm_layer_idx &&
             llm_layer_idx < config.layer_configs.size()) {
    const auto& layer_config = config.layer_configs[llm_layer_idx];
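For example, the per-layer fallback source name built above resolves as follows for img_layer_idx == 3:

#include <cstdio>
#include <string>

int main() {
  const int img_layer_idx = 3;
  // Prints "img/Transformer/encoderblock_3/LayerNorm_0/scale".
  const std::string name = "img/Transformer/encoderblock_" +
                           std::to_string(img_layer_idx) + "/LayerNorm_0/scale";
  printf("%s\n", name.c_str());
}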
@@ -127,7 +127,9 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
  }

  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
  if (info.wrapping == PromptWrapping::PALIGEMMA) {
  if (info.wrapping == PromptWrapping::PALIGEMMA
      // || info.wrapping == PromptWrapping::GEMMA_VLM
  ) {
    std::vector<int> sep_tokens;
    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());

@@ -136,4 +138,33 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
  return tokens;
}

std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
                         size_t pos, std::vector<int>& tokens,
                         size_t image_batch_size, size_t max_image_batch_size) {
  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);

  std::vector<int> sep_tokens;
  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));

  std::string begin_image_prompt = "\n\n<start_of_image>";
  std::vector<int> begin_image_tokens =
      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);

  std::string end_image_prompt = "<end_of_image>\n\n";
  std::vector<int> end_image_tokens =
      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);

  for (size_t i = 0; i < num_images; ++i) {
    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
                  begin_image_tokens.end());
    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
                  -2);
    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
                  end_image_tokens.begin(), end_image_tokens.end());
  }

  return tokens;
}

}  // namespace gcpp
@@ -57,6 +57,10 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                 const ModelInfo& info, size_t pos,
                                 std::string& prompt);

std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
                         size_t pos, std::vector<int>& tokens,
                         size_t image_batch_size, size_t max_image_batch_size);

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
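A minimal sketch of the prompt layout WrapVLM produces for a single image, with placeholder ids standing in for the real tokenizer output of "\n\n<start_of_image>" and "<end_of_image>\n\n":

#include <cstddef>
#include <vector>

std::vector<int> WrapVlmSketch(std::vector<int> tokens, size_t image_tokens) {
  const std::vector<int> begin = {/*"\n\n<start_of_image>"*/ 1001, 1002};
  const std::vector<int> end = {/*"<end_of_image>\n\n"*/ 1003, 1004};
  tokens.insert(tokens.begin(), begin.begin(), begin.end());
  tokens.insert(tokens.begin() + begin.size(), image_tokens, -2);
  tokens.insert(tokens.begin() + begin.size() + image_tokens, end.begin(),
                end.end());
  // Result: [begin..., -2 x image_tokens, end..., original prompt tokens...].
  return tokens;
}

Each -2 entry is later replaced by one pooled image-token row by EmbedMMToken in gemma-inl.h.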
@@ -99,6 +99,8 @@ struct LayerWeightsPtrs {
        ffw_gating_biases("ffw_gat_b", tensor_index),
        ffw_output_biases("ffw_out_b", tensor_index),
        att_weights("att_w", tensor_index),
        key_norm_scale("key_norm", tensor_index),
        query_norm_scale("query_norm", tensor_index),
        layer_config(config) {}
  ~LayerWeightsPtrs() = default;

@@ -204,6 +206,9 @@ struct LayerWeightsPtrs {
    att_weights.set_scale(attn_vec_einsum_w.scale());
  }

  ArrayT<WeightF32OrBF16> key_norm_scale;
  ArrayT<WeightF32OrBF16> query_norm_scale;

  // Used by ForEachTensor for per-layer tensors.
#define GEMMA_CALL_FUNC(member) \
  {                             \

@@ -282,6 +287,10 @@ struct LayerWeightsPtrs {
      GEMMA_CALL_FUNC(post_attention_norm_scale);
      GEMMA_CALL_FUNC(post_ffw_norm_scale);
    }
    if (ptrs[0]->layer_config.use_qk_norm) {
      GEMMA_CALL_FUNC(key_norm_scale);
      GEMMA_CALL_FUNC(query_norm_scale);
    }

    if (ptrs[0]->layer_config.ff_biases) {
      GEMMA_CALL_FUNC(ffw_gating_biases);

@@ -332,6 +341,7 @@ struct ModelWeightsPtrs {
        vit_img_pos_embedding("img_pos_emb", tensor_index),
        vit_img_head_bias("img_head_bias", tensor_index),
        vit_img_head_kernel("img_head_kernel", tensor_index),
        mm_embed_norm("mm_embed_norm", tensor_index),
        scale_names(config.scale_names),
        weights_config(config) {
    c_layers.reserve(config.layer_configs.size());

@@ -372,6 +382,8 @@ struct ModelWeightsPtrs {
  MatPtrT<float> vit_img_head_bias;
  MatPtrT<WeightF32OrBF16> vit_img_head_kernel;

  MatPtrT<WeightF32OrBF16> mm_embed_norm;

  std::unordered_set<std::string> scale_names;

  const ModelConfig& weights_config;

@@ -488,6 +500,7 @@ struct ModelWeightsPtrs {
      GEMMA_CALL_FUNC(vit_img_pos_embedding);
      GEMMA_CALL_FUNC(vit_img_head_bias);
      GEMMA_CALL_FUNC(vit_img_head_kernel);
      GEMMA_CALL_FUNC(mm_embed_norm);
    }

    for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
@@ -202,7 +202,7 @@ class MMStorage {
  // Compile-time bounds on matrix dimensions to enable pre-allocating storage
  // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
  // per package and 512 MiB, respectively.
  static constexpr size_t kMaxM = 2048;
  static constexpr size_t kMaxM = 4096;
  static constexpr size_t kMaxK = 64 * 1024;
  static constexpr size_t kMaxN = 256 * 1024;
  // Upper bound for per-worker B storage on the stack. Chosen such that one row
@@ -802,6 +802,48 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
                      .prob = topk_logits[topk_sampled_index]};
}

// Performs 4x4 average pooling across row vectors.
// Input has 4096 (64*64) rows, output has 256 (16*16) rows.
// Each output row is the average of a 4x4 block of input rows.
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
  Extents2D extents = input.Extents();
  // Input validation
  HWY_DASSERT(extents.rows == 4096);  // 64 * 64 = 4096 input rows
  // Create output with 256 rows and same number of columns.
  const size_t out_rows = 256;  // 16 * 16 = 256 output rows
  RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
  const size_t input_dim = 64;   // Input is 64×64
  const size_t output_dim = 16;  // Output is 16×16
  for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
    for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
      size_t out_idx = out_row_idx * output_dim + out_col_idx;
      T* output_row = result.Batch(out_idx);
      // Initialize output row to zeros.
      std::fill(output_row, output_row + extents.cols, 0);
      // Average 16 row vectors from a 4x4 block.
      for (size_t i = 0; i < 4; ++i) {
        for (size_t j = 0; j < 4; ++j) {
          size_t in_row_idx = out_row_idx * 4 + i;
          size_t in_col_idx = out_col_idx * 4 + j;
          size_t in_idx = in_row_idx * input_dim + in_col_idx;
          const T* input_row = input.Batch(in_idx);
          // Add each input row to the output.
          // TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
          for (size_t col = 0; col < extents.cols; ++col) {
            output_row[col] += input_row[col];
          }
        }
      }
      // Divide by 16 to get the average.
      for (size_t col = 0; col < extents.cols; ++col) {
        output_row[col] *= T{0.0625};
      }
    }
  }
  return result;
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
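The index mapping inside AvgPool4x4 can be checked in isolation: each of the 256 output rows averages a 4x4 block of the 64x64 grid of input rows. A small standalone helper that lists the 16 source rows feeding one output row (illustrative only, independent of RowVectorBatch):

#include <array>
#include <cstddef>
#include <cstdio>

std::array<size_t, 16> Pool4x4Sources(size_t out_idx) {
  constexpr size_t kInputDim = 64, kOutputDim = 16;
  const size_t out_row = out_idx / kOutputDim;
  const size_t out_col = out_idx % kOutputDim;
  std::array<size_t, 16> sources;
  for (size_t i = 0; i < 4; ++i) {
    for (size_t j = 0; j < 4; ++j) {
      sources[i * 4 + j] = (out_row * 4 + i) * kInputDim + (out_col * 4 + j);
    }
  }
  return sources;
}

int main() {
  // Output row 17 (row 1, col 1 of the 16x16 grid) averages input rows
  // 260..263, 324..327, 388..391, 452..455.
  for (size_t idx : Pool4x4Sources(17)) printf("%zu ", idx);
  printf("\n");
}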