mirror of https://github.com/google/gemma.cpp.git
parent
9d83ff202e
commit
4ab601da10
|
|
@ -189,6 +189,7 @@ constexpr bool IsNuqStream() {
|
||||||
enum class PromptWrapping {
|
enum class PromptWrapping {
|
||||||
GEMMA_IT,
|
GEMMA_IT,
|
||||||
GEMMA_PT,
|
GEMMA_PT,
|
||||||
|
GEMMA_VLM,
|
||||||
PALIGEMMA,
|
PALIGEMMA,
|
||||||
kSentinel // must be last
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,7 @@ struct Activations {
|
||||||
|
|
||||||
// Rope
|
// Rope
|
||||||
RowVectorBatch<float> inv_timescale;
|
RowVectorBatch<float> inv_timescale;
|
||||||
|
RowVectorBatch<float> inv_timescale_global;
|
||||||
|
|
||||||
// Dynamic because no default ctor and only initialized in `Allocate`.
|
// Dynamic because no default ctor and only initialized in `Allocate`.
|
||||||
MatMulEnv* env;
|
MatMulEnv* env;
|
||||||
|
|
@ -108,6 +109,8 @@ struct Activations {
|
||||||
|
|
||||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
|
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
|
||||||
post_qk == PostQKType::HalfRope);
|
post_qk == PostQKType::HalfRope);
|
||||||
|
inv_timescale_global =
|
||||||
|
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
||||||
|
|
||||||
this->env = env;
|
this->env = env;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = {
|
||||||
"paligemma2-3b-448", // PaliGemma2 3B 448
|
"paligemma2-3b-448", // PaliGemma2 3B 448
|
||||||
"paligemma2-10b-224", // PaliGemma2 10B 224
|
"paligemma2-10b-224", // PaliGemma2 10B 224
|
||||||
"paligemma2-10b-448", // PaliGemma2 10B 448
|
"paligemma2-10b-448", // PaliGemma2 10B 448
|
||||||
|
"gemma3-4b", // Gemma3 4B
|
||||||
|
"gemma3-1b", // Gemma3 1B
|
||||||
|
"gemma3-12b", // Gemma3 12B
|
||||||
|
"gemma3-27b", // Gemma3 27B
|
||||||
};
|
};
|
||||||
constexpr Model kModelTypes[] = {
|
constexpr Model kModelTypes[] = {
|
||||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||||
|
|
@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = {
|
||||||
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
|
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
|
||||||
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
|
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
|
||||||
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
|
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
|
||||||
|
Model::GEMMA3_4B, // Gemma3 4B
|
||||||
|
Model::GEMMA3_1B, // Gemma3 1B
|
||||||
|
Model::GEMMA3_12B, // Gemma3 12B
|
||||||
|
Model::GEMMA3_27B, // Gemma3 27B
|
||||||
};
|
};
|
||||||
constexpr PromptWrapping kPromptWrapping[] = {
|
constexpr PromptWrapping kPromptWrapping[] = {
|
||||||
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
|
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
|
||||||
|
|
@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = {
|
||||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
|
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
|
||||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
|
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
|
||||||
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
|
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
|
||||||
|
PromptWrapping::GEMMA_VLM, // Gemma3 4B
|
||||||
|
PromptWrapping::GEMMA_PT, // Gemma3 1B
|
||||||
|
PromptWrapping::GEMMA_VLM, // Gemma3 12B
|
||||||
|
PromptWrapping::GEMMA_VLM, // Gemma3 27B
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr size_t kNumModelFlags = std::size(kModelFlags);
|
constexpr size_t kNumModelFlags = std::size(kModelFlags);
|
||||||
|
|
|
||||||
188
gemma/configs.cc
188
gemma/configs.cc
|
|
@ -328,6 +328,186 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Common starting point for all Gemma 3 model configurations: takes the
// `ConfigNoSSM()` defaults and sets both `att_cap` and `final_cap` to 0.0f.
// NOTE(review): 0.0f presumably means "no capping" - confirm against the
// code that consumes these fields.
static ModelConfig ConfigBaseGemmaV3() {
  ModelConfig base = ConfigNoSSM();
  base.final_cap = 0.0f;
  base.att_cap = 0.0f;
  return base;
}
|
||||||
|
|
||||||
|
// 1B does not include a vision encoder.
|
||||||
|
static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
|
||||||
|
LayerConfig config;
|
||||||
|
config.model_dim = model_dim;
|
||||||
|
config.ff_hidden_dim = 6912;
|
||||||
|
config.heads = 4;
|
||||||
|
config.kv_heads = 1;
|
||||||
|
config.qkv_dim = 256;
|
||||||
|
config.optimized_gating = true;
|
||||||
|
config.post_norm = PostNormType::Scale;
|
||||||
|
config.use_qk_norm = true;
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full configuration for Gemma 3 1B (text only): 26 identical layers built
// from `LayerConfigGemma3_1B_LM`, on top of the shared Gemma 3 base config.
static ModelConfig ConfigGemma3_1B() {
  constexpr size_t kNumLayers = 26;

  ModelConfig cfg = ConfigBaseGemmaV3();
  cfg.model = Model::GEMMA3_1B;
  cfg.model_name = "Gemma3_1B";
  cfg.query_scale = QueryScaleType::SqrtKeySize;
  cfg.vocab_size = 262144;  // new vocab size / tokenizer
  cfg.seq_len = 32 * 1024;
  cfg.model_dim = 1152;

  const LayerConfig per_layer = LayerConfigGemma3_1B_LM(cfg.model_dim);
  cfg.layer_configs = {kNumLayers, per_layer};
  cfg.num_tensor_scales = 4 * cfg.layer_configs.size();

  // Interleaved local / global attention: the pattern is five 512-token
  // windows followed by one window spanning the whole sequence.
  // NOTE(review): presumably the helper repeats this 6-entry pattern across
  // all layers - confirm against RepeatedAttentionWindowSizes.
  cfg.attention_window_sizes = RepeatedAttentionWindowSizes<kNumLayers, 6>(
      {512, 512, 512, 512, 512, cfg.seq_len});
  return cfg;
}
|
||||||
|
|
||||||
|
// Per-layer settings for the Gemma 3 4B language model.
// `model_dim` is copied into the returned layer config unchanged.
static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;

  // Attention geometry.
  layer.qkv_dim = 256;
  layer.heads = 8;
  layer.kv_heads = 4;
  layer.use_qk_norm = true;

  // Feed-forward block and post-normalization.
  layer.ff_hidden_dim = 8 * 2560 / 2;  // = 10240
  layer.optimized_gating = true;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
|
// Until we have the SigLIP checkpoints included, we use the LM config directly.
|
||||||
|
static ModelConfig ConfigGemma3_4B_LM() {
|
||||||
|
ModelConfig config = ConfigBaseGemmaV3();
|
||||||
|
config.model_name = "Gemma3_4B";
|
||||||
|
config.model = Model::GEMMA3_4B;
|
||||||
|
config.model_dim = 2560;
|
||||||
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
|
config.seq_len = 32 * 1024;
|
||||||
|
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
||||||
|
config.layer_configs = {34, layer_config};
|
||||||
|
config.num_tensor_scales = 4 * config.layer_configs.size();
|
||||||
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
|
// interleaved local / global attention
|
||||||
|
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
||||||
|
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gemma 3 4B with vision: the 4B LM configuration plus ViT settings for
// 896-pixel images and a pooling dimension of 4.
static ModelConfig ConfigGemma3_4B() {
  ModelConfig cfg = ConfigGemma3_4B_LM();
  cfg.model = Model::GEMMA3_4B;
  cfg.model_name = "Gemma3_4B";

  AddVitConfig(cfg, /*image_size=*/896);
  // Everything below intentionally runs after AddVitConfig.
  // NOTE(review): vocab_size is re-assigned here, presumably because
  // AddVitConfig overrides it - confirm.
  cfg.vocab_size = 262144;
  cfg.vit_config.pool_dim = 4;

  // ViT sequence length = (patches per image side)^2.
  const size_t patches_per_side =
      cfg.vit_config.image_size / cfg.vit_config.patch_width;
  cfg.vit_config.seq_len = patches_per_side * patches_per_side;

  // AddVitConfig resets optimized gating to false; Gemma 3 needs it true.
  for (auto& layer : cfg.layer_configs) {
    layer.optimized_gating = true;
  }
  return cfg;
}
|
||||||
|
|
||||||
|
// Per-layer settings for the Gemma 3 12B language model.
// `model_dim` is copied into the returned layer config unchanged.
static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;

  // Attention geometry.
  layer.qkv_dim = 256;
  layer.heads = 16;
  layer.kv_heads = 8;
  layer.use_qk_norm = true;

  // Feed-forward block and post-normalization.
  layer.ff_hidden_dim = 15360;
  layer.optimized_gating = true;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
|
// Language-model part of the Gemma 3 12B configuration: 48 identical layers
// built from `LayerConfigGemma3_12B_LM`. No ViT settings are applied here.
static ModelConfig ConfigGemma3_12B_LM() {
  constexpr size_t kNumLayers = 48;

  ModelConfig cfg = ConfigBaseGemmaV3();
  cfg.model = Model::GEMMA3_12B;
  cfg.model_name = "Gemma3_12B";
  cfg.query_scale = QueryScaleType::SqrtKeySize;
  cfg.vocab_size = 262144;  // new vocab size / tokenizer
  cfg.seq_len = 32 * 1024;
  cfg.model_dim = 3840;

  const LayerConfig per_layer = LayerConfigGemma3_12B_LM(cfg.model_dim);
  cfg.layer_configs = {kNumLayers, per_layer};
  cfg.num_tensor_scales = 4 * cfg.layer_configs.size();

  // Interleaved local / global attention: the pattern is five 1024-token
  // windows followed by one window spanning the whole sequence.
  cfg.attention_window_sizes = RepeatedAttentionWindowSizes<kNumLayers, 6>(
      {1024, 1024, 1024, 1024, 1024, cfg.seq_len});
  return cfg;
}
|
||||||
|
|
||||||
|
// Gemma 3 12B with vision: the 12B LM configuration plus ViT settings for
// 896-pixel images and a pooling dimension of 4.
static ModelConfig ConfigGemma3_12B() {
  ModelConfig cfg = ConfigGemma3_12B_LM();
  cfg.model = Model::GEMMA3_12B;
  cfg.model_name = "Gemma3_12B";

  AddVitConfig(cfg, /*image_size=*/896);
  // Everything below intentionally runs after AddVitConfig.
  // NOTE(review): vocab_size is re-assigned here, presumably because
  // AddVitConfig overrides it - confirm.
  cfg.vocab_size = 262144;
  cfg.vit_config.pool_dim = 4;

  // ViT sequence length = (patches per image side)^2.
  const size_t patches_per_side =
      cfg.vit_config.image_size / cfg.vit_config.patch_width;
  cfg.vit_config.seq_len = patches_per_side * patches_per_side;

  // AddVitConfig resets optimized gating to false; Gemma 3 needs it true.
  for (auto& layer : cfg.layer_configs) {
    layer.optimized_gating = true;
  }
  return cfg;
}
|
||||||
|
|
||||||
|
// Per-layer settings for the Gemma 3 27B language model. Unlike the smaller
// Gemma 3 layer configs in this file, qkv_dim is 128 rather than 256.
// `model_dim` is copied into the returned layer config unchanged.
static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
  LayerConfig layer;
  layer.model_dim = model_dim;

  // Attention geometry.
  layer.qkv_dim = 128;
  layer.heads = 32;
  layer.kv_heads = 16;
  layer.use_qk_norm = true;

  // Feed-forward block and post-normalization.
  layer.ff_hidden_dim = 21504;
  layer.optimized_gating = true;
  layer.post_norm = PostNormType::Scale;
  return layer;
}
|
||||||
|
|
||||||
|
// Language-model part of the Gemma 3 27B configuration: 62 identical layers
// built from `LayerConfigGemma3_27B_LM`. No ViT settings are applied here.
static ModelConfig ConfigGemma3_27B_LM() {
  constexpr size_t kNumLayers = 62;

  ModelConfig cfg = ConfigBaseGemmaV3();
  cfg.model = Model::GEMMA3_27B;
  cfg.model_name = "Gemma3_27B";
  cfg.query_scale = QueryScaleType::SqrtKeySize;
  cfg.vocab_size = 262144;  // new vocab size / tokenizer
  cfg.seq_len = 32 * 1024;
  cfg.model_dim = 5376;

  const LayerConfig per_layer = LayerConfigGemma3_27B_LM(cfg.model_dim);
  cfg.layer_configs = {kNumLayers, per_layer};
  cfg.num_tensor_scales = 4 * cfg.layer_configs.size();

  // Interleaved local / global attention: the pattern is five 1024-token
  // windows followed by one window spanning the whole sequence.
  cfg.attention_window_sizes = RepeatedAttentionWindowSizes<kNumLayers, 6>(
      {1024, 1024, 1024, 1024, 1024, cfg.seq_len});
  return cfg;
}
|
||||||
|
|
||||||
|
// Gemma 3 27B with vision: the 27B LM configuration plus ViT settings for
// 896-pixel images and a pooling dimension of 4.
static ModelConfig ConfigGemma3_27B() {
  ModelConfig cfg = ConfigGemma3_27B_LM();
  cfg.model = Model::GEMMA3_27B;
  cfg.model_name = "Gemma3_27B";

  AddVitConfig(cfg, /*image_size=*/896);
  // Everything below intentionally runs after AddVitConfig.
  // NOTE(review): vocab_size is re-assigned here, presumably because
  // AddVitConfig overrides it - confirm.
  cfg.vocab_size = 262144;
  cfg.vit_config.pool_dim = 4;

  // ViT sequence length = (patches per image side)^2.
  const size_t patches_per_side =
      cfg.vit_config.image_size / cfg.vit_config.patch_width;
  cfg.vit_config.seq_len = patches_per_side * patches_per_side;

  // AddVitConfig resets optimized gating to false; Gemma 3 needs it true.
  for (auto& layer : cfg.layer_configs) {
    layer.optimized_gating = true;
  }
  return cfg;
}
|
||||||
|
|
||||||
ModelConfig ConfigFromModel(Model model) {
|
ModelConfig ConfigFromModel(Model model) {
|
||||||
switch (model) {
|
switch (model) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
|
|
@ -356,6 +536,14 @@ ModelConfig ConfigFromModel(Model model) {
|
||||||
return ConfigPaliGemma2_10B_224();
|
return ConfigPaliGemma2_10B_224();
|
||||||
case Model::PALIGEMMA2_10B_448:
|
case Model::PALIGEMMA2_10B_448:
|
||||||
return ConfigPaliGemma2_10B_448();
|
return ConfigPaliGemma2_10B_448();
|
||||||
|
case Model::GEMMA3_4B:
|
||||||
|
return ConfigGemma3_4B();
|
||||||
|
case Model::GEMMA3_1B:
|
||||||
|
return ConfigGemma3_1B();
|
||||||
|
case Model::GEMMA3_12B:
|
||||||
|
return ConfigGemma3_12B();
|
||||||
|
case Model::GEMMA3_27B:
|
||||||
|
return ConfigGemma3_27B();
|
||||||
default:
|
default:
|
||||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -151,6 +151,10 @@ enum class Model {
|
||||||
PALIGEMMA2_3B_448,
|
PALIGEMMA2_3B_448,
|
||||||
PALIGEMMA2_10B_224,
|
PALIGEMMA2_10B_224,
|
||||||
PALIGEMMA2_10B_448,
|
PALIGEMMA2_10B_448,
|
||||||
|
GEMMA3_4B,
|
||||||
|
GEMMA3_1B,
|
||||||
|
GEMMA3_12B,
|
||||||
|
GEMMA3_27B,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Allows the Model enum to be iterated over.
|
// Allows the Model enum to be iterated over.
|
||||||
|
|
@ -159,7 +163,8 @@ static constexpr Model kAllModels[] = {
|
||||||
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
|
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
|
||||||
Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
|
Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
|
||||||
Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
|
Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
|
||||||
Model::PALIGEMMA2_10B_448,
|
Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
|
||||||
|
Model::GEMMA3_12B, Model::GEMMA3_27B,
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(Model model) {
|
inline bool EnumValid(Model model) {
|
||||||
|
|
@ -202,6 +207,7 @@ struct LayerConfig : public IFields {
|
||||||
visitor(type);
|
visitor(type);
|
||||||
visitor(activation);
|
visitor(activation);
|
||||||
visitor(post_qk);
|
visitor(post_qk);
|
||||||
|
visitor(use_qk_norm);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t model_dim = 0;
|
uint32_t model_dim = 0;
|
||||||
|
|
@ -218,6 +224,7 @@ struct LayerConfig : public IFields {
|
||||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||||
ActivationType activation = ActivationType::Gelu;
|
ActivationType activation = ActivationType::Gelu;
|
||||||
PostQKType post_qk = PostQKType::Rope;
|
PostQKType post_qk = PostQKType::Rope;
|
||||||
|
bool use_qk_norm = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Dimensions related to image processing.
|
// Dimensions related to image processing.
|
||||||
|
|
|
||||||
|
|
@ -219,6 +219,17 @@ class GemmaAttention {
|
||||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||||
const float* inv_timescale = activations_.inv_timescale.Const();
|
const float* inv_timescale = activations_.inv_timescale.Const();
|
||||||
|
bool is_global_layer =
|
||||||
|
activations_.weights_config.attention_window_sizes[layer] ==
|
||||||
|
activations_.seq_len;
|
||||||
|
// TODO: add a config flag instead of hardcoding the model.
|
||||||
|
if (is_global_layer &&
|
||||||
|
(activations_.weights_config.model == Model::GEMMA3_4B ||
|
||||||
|
activations_.weights_config.model == Model::GEMMA3_12B ||
|
||||||
|
activations_.weights_config.model == Model::GEMMA3_27B ||
|
||||||
|
activations_.weights_config.model == Model::GEMMA3_1B)) {
|
||||||
|
inv_timescale = activations_.inv_timescale_global.Const();
|
||||||
|
}
|
||||||
// PostQKType::Rope
|
// PostQKType::Rope
|
||||||
(void)layer;
|
(void)layer;
|
||||||
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
|
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
|
||||||
|
|
@ -324,6 +335,10 @@ class GemmaAttention {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply further processing to K.
|
// Apply further processing to K.
|
||||||
|
if (layer_weights_.key_norm_scale.data()) {
|
||||||
|
RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
|
||||||
|
qkv_dim);
|
||||||
|
}
|
||||||
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
|
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -411,6 +426,10 @@ class GemmaAttention {
|
||||||
|
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||||
|
if (layer_weights_.query_norm_scale.data()) {
|
||||||
|
RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
|
||||||
|
qkv_dim);
|
||||||
|
}
|
||||||
PositionalEncodingQK(q, pos, layer_, query_scale);
|
PositionalEncodingQK(q, pos, layer_, query_scale);
|
||||||
|
|
||||||
const size_t start_pos = StartPos(pos, layer_);
|
const size_t start_pos = StartPos(pos, layer_);
|
||||||
|
|
@ -590,6 +609,7 @@ class VitAttention {
|
||||||
RowPtrFromBatch(qkv));
|
RowPtrFromBatch(qkv));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(philculliton): transition fully to MatMul.
|
||||||
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
||||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||||
const size_t heads = layer_config_.heads;
|
const size_t heads = layer_config_.heads;
|
||||||
|
|
@ -598,35 +618,56 @@ class VitAttention {
|
||||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||||
|
|
||||||
// Compute Q.K, softmax, and weighted V.
|
// Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
|
||||||
pool_.Run(0, layer_config_.heads * num_tokens_,
|
RowVectorBatch<float> Q =
|
||||||
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
|
||||||
const size_t head = task % layer_config_.heads;
|
RowVectorBatch<float> K =
|
||||||
const size_t token = task / layer_config_.heads;
|
AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
|
||||||
// Compute Q.K scores, which are "logits" stored in head_att.
|
RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
|
||||||
|
|
||||||
|
// Initialize att_out to zero prior to head loop.
|
||||||
|
hwy::ZeroBytes(activations_.att_out.All(),
|
||||||
|
num_tokens_ * heads * qkv_dim * sizeof(float));
|
||||||
|
|
||||||
|
for (size_t head = 0; head < heads; ++head) {
|
||||||
|
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
const size_t token = task;
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations_.q.Batch(token) + head * 3 * qkv_dim;
|
activations_.q.Batch(token) + head * 3 * qkv_dim;
|
||||||
|
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
|
||||||
MulByConst(query_scale, q, qkv_dim);
|
MulByConst(query_scale, q, qkv_dim);
|
||||||
float* HWY_RESTRICT head_att =
|
hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
|
||||||
activations_.att.Batch(token) + head * activations_.seq_len;
|
});
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
|
||||||
|
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
const size_t seq_idx = task;
|
||||||
float* HWY_RESTRICT k =
|
float* HWY_RESTRICT k =
|
||||||
activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
|
activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
|
||||||
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
|
||||||
}
|
});
|
||||||
// SoftMax yields "probabilities" in head_att.
|
|
||||||
Softmax(head_att, seq_len);
|
// this produces C, a (num_tokens_, seq_len) matrix of dot products
|
||||||
// Compute weighted sum of v into att_out.
|
MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
|
||||||
|
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
|
||||||
|
RowPtrFromBatch(C));
|
||||||
|
|
||||||
|
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
float* HWY_RESTRICT c = C.Batch(task);
|
||||||
|
Softmax(c, C.Cols());
|
||||||
|
});
|
||||||
|
|
||||||
|
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||||
|
size_t token = task;
|
||||||
float* HWY_RESTRICT att_out =
|
float* HWY_RESTRICT att_out =
|
||||||
activations_.att_out.Batch(token) + head * qkv_dim;
|
activations_.att_out.Batch(token) + head * qkv_dim;
|
||||||
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
|
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < seq_len; ++i) {
|
||||||
float* HWY_RESTRICT v = activations_.q.Batch(i) +
|
float* HWY_RESTRICT v =
|
||||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
|
||||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
|
MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
|
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
|
||||||
// head_dim (`qkv_dim`) into output (`att_sums`).
|
// head_dim (`qkv_dim`) into output (`att_sums`).
|
||||||
|
|
@ -784,14 +825,30 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
||||||
// `batch_idx` indicates which row of `x` to write to.
|
// `batch_idx` indicates which row of `x` to write to.
|
||||||
// `pos` is the *token*'s position, not the start of the batch, because this is
|
// `pos` is the *token*'s position, not the start of the batch, because this is
|
||||||
// called for batches of tokens in prefill, but batches of queries in decode.
|
// called for batches of tokens in prefill, but batches of queries in decode.
|
||||||
|
//
|
||||||
|
// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3
|
||||||
|
// spec) until we run out of image tokens. This allows for a multi-image prompt
|
||||||
|
// if -2 locations with appropriate begin/end image tokens are created by the
|
||||||
|
// calling application.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
|
||||||
size_t pos_in_prompt,
|
size_t pos_in_prompt,
|
||||||
const ModelWeightsPtrs<T>& weights,
|
const ModelWeightsPtrs<T>& weights,
|
||||||
RowVectorBatch<float>& x,
|
RowVectorBatch<float>& x,
|
||||||
const ImageTokens* image_tokens) {
|
const ImageTokens* image_tokens,
|
||||||
|
size_t& image_token_position) {
|
||||||
// Image tokens just need to be copied.
|
// Image tokens just need to be copied.
|
||||||
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
|
||||||
|
image_tokens != nullptr && token == -2 &&
|
||||||
|
image_token_position < image_tokens->BatchSize()) {
|
||||||
|
hwy::CopyBytes(image_tokens->Batch(image_token_position),
|
||||||
|
x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
|
||||||
|
image_token_position++;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
|
||||||
|
image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
|
||||||
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
|
||||||
x.Cols() * sizeof(x.Const()[0]));
|
x.Cols() * sizeof(x.Const()[0]));
|
||||||
return;
|
return;
|
||||||
|
|
@ -816,6 +873,21 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `batch_idx` indicates which row of `x` to write to.
|
||||||
|
// `pos` is the *token*'s position, not the start of the batch, because this is
|
||||||
|
// called for batches of tokens in prefill, but batches of queries in decode.
|
||||||
|
// This version of the function doesn't track internal image token position.
|
||||||
|
template <typename T>
|
||||||
|
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
|
size_t pos_in_prompt,
|
||||||
|
const ModelWeightsPtrs<T>& weights,
|
||||||
|
RowVectorBatch<float>& x,
|
||||||
|
const ImageTokens* image_tokens) {
|
||||||
|
size_t image_token_position = 0;
|
||||||
|
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
|
||||||
|
image_tokens, image_token_position);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Weights, typename T>
|
template <typename Weights, typename T>
|
||||||
HWY_NOINLINE void ResidualConnection(
|
HWY_NOINLINE void ResidualConnection(
|
||||||
size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
|
||||||
|
|
@ -990,12 +1062,13 @@ HWY_NOINLINE void Prefill(
|
||||||
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
|
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
|
||||||
|
|
||||||
// Fill activations.x (much faster than TransformerLayer).
|
// Fill activations.x (much faster than TransformerLayer).
|
||||||
|
size_t image_token_position = 0;
|
||||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||||
const size_t pos = queries_pos[qi] + ti;
|
const size_t pos = queries_pos[qi] + ti;
|
||||||
const size_t pos_in_prompt = tbatch_start + ti;
|
const size_t pos_in_prompt = tbatch_start + ti;
|
||||||
const int token = queries_prompt[qi][pos_in_prompt];
|
const int token = queries_prompt[qi][pos_in_prompt];
|
||||||
EmbedToken(token, ti, pos, pos_in_prompt, weights, activations.x,
|
EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
|
||||||
runtime_config.image_tokens);
|
runtime_config.image_tokens, image_token_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transformer with one batch of tokens from a single query.
|
// Transformer with one batch of tokens from a single query.
|
||||||
|
|
@ -1104,8 +1177,14 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
||||||
weights.vit_encoder_norm_bias.data_scale1(),
|
weights.vit_encoder_norm_bias.data_scale1(),
|
||||||
activations.x.All(), vit_model_dim);
|
activations.x.All(), vit_model_dim);
|
||||||
|
|
||||||
|
activations.x = AvgPool4x4(activations.x);
|
||||||
|
|
||||||
|
// Apply soft embedding norm before input projection.
|
||||||
|
RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
|
||||||
|
vit_model_dim);
|
||||||
|
|
||||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
||||||
MatMul(ConstMatFromBatch(num_tokens, activations.x),
|
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
|
||||||
ConstMatFromWeights(weights.vit_img_head_kernel),
|
ConstMatFromWeights(weights.vit_img_head_kernel),
|
||||||
weights.vit_img_head_bias.data_scale1(), *activations.env,
|
weights.vit_img_head_bias.data_scale1(), *activations.env,
|
||||||
RowPtrFromBatch(image_tokens));
|
RowPtrFromBatch(image_tokens));
|
||||||
|
|
@ -1133,10 +1212,11 @@ HWY_NOINLINE void Transformer(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t image_token_position = 0;
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
EmbedToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
|
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
|
||||||
/*pos_in_prompt=*/0, weights, activations.x,
|
/*pos_in_prompt=*/0, weights, activations.x,
|
||||||
/*image_tokens=*/nullptr);
|
/*image_tokens=*/nullptr, image_token_position);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
|
for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
|
||||||
|
|
@ -1440,7 +1520,8 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
|
||||||
}
|
}
|
||||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||||
ModelConfig vit_config = GetVitConfig(model.Config());
|
ModelConfig vit_config = GetVitConfig(model.Config());
|
||||||
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
|
prefill_runtime_config.prefill_tbatch_size =
|
||||||
|
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
|
||||||
Activations prefill_activations(vit_config);
|
Activations prefill_activations(vit_config);
|
||||||
prefill_activations.Allocate(vit_config.seq_len, env);
|
prefill_activations.Allocate(vit_config.seq_len, env);
|
||||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||||
|
|
|
||||||
19
gemma/run.cc
19
gemma/run.cc
|
|
@ -94,10 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||||
Image image;
|
Image image;
|
||||||
ImageTokens image_tokens;
|
ImageTokens image_tokens;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
image_tokens =
|
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||||
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
|
image_tokens = ImageTokens(Extents2D(
|
||||||
|
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
|
||||||
model.GetModelConfig().model_dim));
|
model.GetModelConfig().model_dim));
|
||||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
|
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||||
|
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||||
image.Resize(image_size, image_size);
|
image.Resize(image_size, image_size);
|
||||||
|
|
@ -190,7 +192,15 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||||
size_t prefix_end = 0;
|
size_t prefix_end = 0;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
runtime_config.image_tokens = &image_tokens;
|
runtime_config.image_tokens = &image_tokens;
|
||||||
|
if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||||
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
|
||||||
|
} else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
|
||||||
|
size_t seq_len = model.GetModelConfig().vit_config.seq_len;
|
||||||
|
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||||
|
prompt =
|
||||||
|
WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
|
||||||
|
image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
|
||||||
|
}
|
||||||
prompt_size = prompt.size();
|
prompt_size = prompt.size();
|
||||||
// The end of the prefix for prefix-LM style attention in Paligemma.
|
// The end of the prefix for prefix-LM style attention in Paligemma.
|
||||||
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
// See Figure 2 of https://arxiv.org/abs/2407.07726.
|
||||||
|
|
@ -209,9 +219,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||||
|
|
||||||
// Prepare for the next turn.
|
// Prepare for the next turn.
|
||||||
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||||
abs_pos = 0; // Start a new turn at position 0.
|
|
||||||
InitGenerator(args, gen);
|
|
||||||
} else {
|
|
||||||
// The last token was either EOS, then it should be ignored because it is
|
// The last token was either EOS, then it should be ignored because it is
|
||||||
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
||||||
// https://arxiv.org/pdf/2408.00118
|
// https://arxiv.org/pdf/2408.00118
|
||||||
|
|
|
||||||
|
|
@ -64,14 +64,14 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "img_head_bias",
|
.name = "img_head_bias",
|
||||||
.source_names = {"img/head/bias"},
|
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
|
||||||
.axes = {0},
|
.axes = {0},
|
||||||
.shape = {config.model_dim},
|
.shape = {config.model_dim},
|
||||||
.min_size = Type::kF32,
|
.min_size = Type::kF32,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "img_head_kernel",
|
.name = "img_head_kernel",
|
||||||
.source_names = {"img/head/kernel"},
|
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
|
||||||
.axes = {1, 0},
|
.axes = {1, 0},
|
||||||
.shape = {config.model_dim, config.vit_config.model_dim},
|
.shape = {config.model_dim, config.vit_config.model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
|
|
@ -84,12 +84,21 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
|
||||||
config.vit_config.model_dim},
|
config.vit_config.model_dim},
|
||||||
.min_size = Type::kF32,
|
.min_size = Type::kF32,
|
||||||
},
|
},
|
||||||
|
// RMS norm applied to soft tokens prior to pos embedding.
|
||||||
|
TensorInfo{
|
||||||
|
.name = "mm_embed_norm",
|
||||||
|
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
|
||||||
|
.axes = {0},
|
||||||
|
.shape = {config.vit_config.model_dim},
|
||||||
|
.min_size = Type::kBF16,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the tensors for the given image layer config.
|
// Returns the tensors for the given image layer config.
|
||||||
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||||
const LayerConfig& layer_config) {
|
const LayerConfig& layer_config,
|
||||||
|
const int img_layer_idx) {
|
||||||
return {
|
return {
|
||||||
// Vit layers.
|
// Vit layers.
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
|
|
@ -207,28 +216,40 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "ln_0_bias",
|
.name = "ln_0_bias",
|
||||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
|
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
|
||||||
|
"img/Transformer/encoderblock_" +
|
||||||
|
std::to_string(img_layer_idx) +
|
||||||
|
"/LayerNorm_0/bias"},
|
||||||
.axes = {0},
|
.axes = {0},
|
||||||
.shape = {config.vit_config.model_dim},
|
.shape = {config.vit_config.model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "ln_0_scale",
|
.name = "ln_0_scale",
|
||||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
|
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
|
||||||
|
"img/Transformer/encoderblock_" +
|
||||||
|
std::to_string(img_layer_idx) +
|
||||||
|
"/LayerNorm_0/scale"},
|
||||||
.axes = {0},
|
.axes = {0},
|
||||||
.shape = {config.vit_config.model_dim},
|
.shape = {config.vit_config.model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "ln_1_bias",
|
.name = "ln_1_bias",
|
||||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
|
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
|
||||||
|
"img/Transformer/encoderblock_" +
|
||||||
|
std::to_string(img_layer_idx) +
|
||||||
|
"/LayerNorm_1/bias"},
|
||||||
.axes = {0},
|
.axes = {0},
|
||||||
.shape = {config.vit_config.model_dim},
|
.shape = {config.vit_config.model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
},
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "ln_1_scale",
|
.name = "ln_1_scale",
|
||||||
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
|
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
|
||||||
|
"img/Transformer/encoderblock_" +
|
||||||
|
std::to_string(img_layer_idx) +
|
||||||
|
"/LayerNorm_1/scale"},
|
||||||
.axes = {0},
|
.axes = {0},
|
||||||
.shape = {config.vit_config.model_dim},
|
.shape = {config.vit_config.model_dim},
|
||||||
.min_size = Type::kBF16,
|
.min_size = Type::kBF16,
|
||||||
|
|
@ -241,6 +262,20 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
|
||||||
const LayerConfig& layer_config,
|
const LayerConfig& layer_config,
|
||||||
bool reshape_att) {
|
bool reshape_att) {
|
||||||
std::vector<TensorInfo> tensors = {
|
std::vector<TensorInfo> tensors = {
|
||||||
|
TensorInfo{
|
||||||
|
.name = "key_norm",
|
||||||
|
.source_names = {"attn/_key_norm/scale"},
|
||||||
|
.axes = {0},
|
||||||
|
.shape = {layer_config.qkv_dim},
|
||||||
|
.min_size = Type::kBF16,
|
||||||
|
},
|
||||||
|
TensorInfo{
|
||||||
|
.name = "query_norm",
|
||||||
|
.source_names = {"attn/_query_norm/scale"},
|
||||||
|
.axes = {0},
|
||||||
|
.shape = {layer_config.qkv_dim},
|
||||||
|
.min_size = Type::kBF16,
|
||||||
|
},
|
||||||
TensorInfo{
|
TensorInfo{
|
||||||
.name = "qkv1_w",
|
.name = "qkv1_w",
|
||||||
.source_names = {"attn/q_einsum/w"},
|
.source_names = {"attn/q_einsum/w"},
|
||||||
|
|
@ -529,7 +564,7 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
|
||||||
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
|
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
|
||||||
img_layer_idx < config.vit_config.layer_configs.size()) {
|
img_layer_idx < config.vit_config.layer_configs.size()) {
|
||||||
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
|
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
|
||||||
tensors_ = ImageLayerTensors(config, layer_config);
|
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
|
||||||
} else if (0 <= llm_layer_idx &&
|
} else if (0 <= llm_layer_idx &&
|
||||||
llm_layer_idx < config.layer_configs.size()) {
|
llm_layer_idx < config.layer_configs.size()) {
|
||||||
const auto& layer_config = config.layer_configs[llm_layer_idx];
|
const auto& layer_config = config.layer_configs[llm_layer_idx];
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,9 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
|
||||||
if (info.wrapping == PromptWrapping::PALIGEMMA) {
|
if (info.wrapping == PromptWrapping::PALIGEMMA
|
||||||
|
// || info.wrapping == PromptWrapping::GEMMA_VLM
|
||||||
|
) {
|
||||||
std::vector<int> sep_tokens;
|
std::vector<int> sep_tokens;
|
||||||
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||||
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
|
||||||
|
|
@ -136,4 +138,33 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
|
||||||
|
size_t pos, std::vector<int>& tokens,
|
||||||
|
size_t image_batch_size, size_t max_image_batch_size) {
|
||||||
|
HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
|
||||||
|
size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
|
||||||
|
|
||||||
|
std::vector<int> sep_tokens;
|
||||||
|
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
|
||||||
|
|
||||||
|
std::string begin_image_prompt = "\n\n<start_of_image>";
|
||||||
|
std::vector<int> begin_image_tokens =
|
||||||
|
WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
|
||||||
|
|
||||||
|
std::string end_image_prompt = "<end_of_image>\n\n";
|
||||||
|
std::vector<int> end_image_tokens =
|
||||||
|
WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_images; ++i) {
|
||||||
|
tokens.insert(tokens.begin(), begin_image_tokens.begin(),
|
||||||
|
begin_image_tokens.end());
|
||||||
|
tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
|
||||||
|
-2);
|
||||||
|
tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
|
||||||
|
end_image_tokens.begin(), end_image_tokens.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,10 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
|
||||||
const ModelInfo& info, size_t pos,
|
const ModelInfo& info, size_t pos,
|
||||||
std::string& prompt);
|
std::string& prompt);
|
||||||
|
|
||||||
|
// Prepends image sections to `tokens` (which is modified in place) and returns
// a copy of the result. Requires `info.wrapping == PromptWrapping::GEMMA_VLM`.
// One section is added per image, where the number of images is
// ceil(image_batch_size / max_image_batch_size). Each section consists of the
// tokenized "\n\n<start_of_image>", a run of placeholder tokens (-2), and the
// tokenized "<end_of_image>\n\n". `pos` is forwarded to `WrapAndTokenize` when
// tokenizing the markers.
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
|
||||||
|
size_t pos, std::vector<int>& tokens,
|
||||||
|
size_t image_batch_size, size_t max_image_batch_size);
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
|
||||||
|
|
|
||||||
|
|
@ -99,6 +99,8 @@ struct LayerWeightsPtrs {
|
||||||
ffw_gating_biases("ffw_gat_b", tensor_index),
|
ffw_gating_biases("ffw_gat_b", tensor_index),
|
||||||
ffw_output_biases("ffw_out_b", tensor_index),
|
ffw_output_biases("ffw_out_b", tensor_index),
|
||||||
att_weights("att_w", tensor_index),
|
att_weights("att_w", tensor_index),
|
||||||
|
key_norm_scale("key_norm", tensor_index),
|
||||||
|
query_norm_scale("query_norm", tensor_index),
|
||||||
layer_config(config) {}
|
layer_config(config) {}
|
||||||
~LayerWeightsPtrs() = default;
|
~LayerWeightsPtrs() = default;
|
||||||
|
|
||||||
|
|
@ -204,6 +206,9 @@ struct LayerWeightsPtrs {
|
||||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ArrayT<WeightF32OrBF16> key_norm_scale;
|
||||||
|
ArrayT<WeightF32OrBF16> query_norm_scale;
|
||||||
|
|
||||||
// Used by ForEachTensor for per-layer tensors.
|
// Used by ForEachTensor for per-layer tensors.
|
||||||
#define GEMMA_CALL_FUNC(member) \
|
#define GEMMA_CALL_FUNC(member) \
|
||||||
{ \
|
{ \
|
||||||
|
|
@ -282,6 +287,10 @@ struct LayerWeightsPtrs {
|
||||||
GEMMA_CALL_FUNC(post_attention_norm_scale);
|
GEMMA_CALL_FUNC(post_attention_norm_scale);
|
||||||
GEMMA_CALL_FUNC(post_ffw_norm_scale);
|
GEMMA_CALL_FUNC(post_ffw_norm_scale);
|
||||||
}
|
}
|
||||||
|
if (ptrs[0]->layer_config.use_qk_norm) {
|
||||||
|
GEMMA_CALL_FUNC(key_norm_scale);
|
||||||
|
GEMMA_CALL_FUNC(query_norm_scale);
|
||||||
|
}
|
||||||
|
|
||||||
if (ptrs[0]->layer_config.ff_biases) {
|
if (ptrs[0]->layer_config.ff_biases) {
|
||||||
GEMMA_CALL_FUNC(ffw_gating_biases);
|
GEMMA_CALL_FUNC(ffw_gating_biases);
|
||||||
|
|
@ -332,6 +341,7 @@ struct ModelWeightsPtrs {
|
||||||
vit_img_pos_embedding("img_pos_emb", tensor_index),
|
vit_img_pos_embedding("img_pos_emb", tensor_index),
|
||||||
vit_img_head_bias("img_head_bias", tensor_index),
|
vit_img_head_bias("img_head_bias", tensor_index),
|
||||||
vit_img_head_kernel("img_head_kernel", tensor_index),
|
vit_img_head_kernel("img_head_kernel", tensor_index),
|
||||||
|
mm_embed_norm("mm_embed_norm", tensor_index),
|
||||||
scale_names(config.scale_names),
|
scale_names(config.scale_names),
|
||||||
weights_config(config) {
|
weights_config(config) {
|
||||||
c_layers.reserve(config.layer_configs.size());
|
c_layers.reserve(config.layer_configs.size());
|
||||||
|
|
@ -372,6 +382,8 @@ struct ModelWeightsPtrs {
|
||||||
MatPtrT<float> vit_img_head_bias;
|
MatPtrT<float> vit_img_head_bias;
|
||||||
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
|
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
|
||||||
|
|
||||||
|
MatPtrT<WeightF32OrBF16> mm_embed_norm;
|
||||||
|
|
||||||
std::unordered_set<std::string> scale_names;
|
std::unordered_set<std::string> scale_names;
|
||||||
|
|
||||||
const ModelConfig& weights_config;
|
const ModelConfig& weights_config;
|
||||||
|
|
@ -488,6 +500,7 @@ struct ModelWeightsPtrs {
|
||||||
GEMMA_CALL_FUNC(vit_img_pos_embedding);
|
GEMMA_CALL_FUNC(vit_img_pos_embedding);
|
||||||
GEMMA_CALL_FUNC(vit_img_head_bias);
|
GEMMA_CALL_FUNC(vit_img_head_bias);
|
||||||
GEMMA_CALL_FUNC(vit_img_head_kernel);
|
GEMMA_CALL_FUNC(vit_img_head_kernel);
|
||||||
|
GEMMA_CALL_FUNC(mm_embed_norm);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
|
for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
|
||||||
|
|
|
||||||
|
|
@ -202,7 +202,7 @@ class MMStorage {
|
||||||
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
|
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
|
||||||
// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
|
// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
|
||||||
// per package and 512 MiB, respectively.
|
// per package and 512 MiB, respectively.
|
||||||
static constexpr size_t kMaxM = 2048;
|
static constexpr size_t kMaxM = 4096;
|
||||||
static constexpr size_t kMaxK = 64 * 1024;
|
static constexpr size_t kMaxK = 64 * 1024;
|
||||||
static constexpr size_t kMaxN = 256 * 1024;
|
static constexpr size_t kMaxN = 256 * 1024;
|
||||||
// Upper bound for per-worker B storage on the stack. Chosen such that one row
|
// Upper bound for per-worker B storage on the stack. Chosen such that one row
|
||||||
|
|
|
||||||
|
|
@ -802,6 +802,48 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
||||||
.prob = topk_logits[topk_sampled_index]};
|
.prob = topk_logits[topk_sampled_index]};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Performs 4x4 average pooling across row vectors
|
||||||
|
// Input has 4096 (64*64) rows, output has 256 (16*16) rows
|
||||||
|
// Each output row is the average of a 4x4 block of input rows
|
||||||
|
template <typename T>
|
||||||
|
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
|
||||||
|
Extents2D extents = input.Extents();
|
||||||
|
// Input validation
|
||||||
|
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
|
||||||
|
// Create output with 256 rows and same number of columns
|
||||||
|
const size_t out_rows = 256; // 16 * 16 = 256 output rows
|
||||||
|
RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
|
||||||
|
const size_t input_dim = 64; // Input is 64×64
|
||||||
|
const size_t output_dim = 16; // Output is 16×16
|
||||||
|
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
|
||||||
|
for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
|
||||||
|
size_t out_idx = out_row_idx * output_dim + out_col_idx;
|
||||||
|
T* output_row = result.Batch(out_idx);
|
||||||
|
// Initialize output row to zeros
|
||||||
|
std::fill(output_row, output_row + extents.cols, 0);
|
||||||
|
// Average 16 row vectors from a 4x4 block
|
||||||
|
for (size_t i = 0; i < 4; ++i) {
|
||||||
|
for (size_t j = 0; j < 4; ++j) {
|
||||||
|
size_t in_row_idx = out_row_idx * 4 + i;
|
||||||
|
size_t in_col_idx = out_col_idx * 4 + j;
|
||||||
|
size_t in_idx = in_row_idx * input_dim + in_col_idx;
|
||||||
|
const T* input_row = input.Batch(in_idx);
|
||||||
|
// Add each input row to the output
|
||||||
|
// TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
|
||||||
|
for (size_t col = 0; col < extents.cols; ++col) {
|
||||||
|
output_row[col] += input_row[col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Divide by 16 to get the average
|
||||||
|
for (size_t col = 0; col < extents.cols; ++col) {
|
||||||
|
output_row[col] *= T{0.0625};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue