mirror of https://github.com/google/gemma.cpp.git
Fix PaliGemma's GenerateImageTokensT().
Move image-related config values from LayerConfig to ModelConfig. Minor changes: add a few comments, remove the gcpp:: qualification where it was not needed in a few places, and define local constants in VitAttention::DotSoftmaxWeightedSum().

PiperOrigin-RevId: 687210519
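In outline, the change moves the per-image dimensions from the per-layer config up to the model-level config. A minimal sketch of the resulting layout (field names as in gemma/configs.h below; all other fields omitted, so this is an illustration rather than the real structs):

  #include <cstddef>

  struct LayerConfig {   // per-layer settings only
    std::size_t model_dim = 0;
    std::size_t ff_hidden_dim = 0;
    // patch_width and image_size no longer live here.
  };

  struct ModelConfig {   // whole-model settings
    std::size_t model_dim = 0;
    // Dimensions related to image processing (moved from LayerConfig).
    int patch_width = 14;
    int image_size = 224;
  };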
parent 0d68555f87
commit c6384574db
@@ -20,12 +20,13 @@
 #include <string>

 #include "compression/shared.h"  // ModelTraining
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/base.h"  // ConvertScalarTo

 namespace gcpp {

 // TODO(janwas): merge with functions below.
 // Struct to bundle model information.
 struct ModelInfo {
   Model model;
   ModelTraining training;

@@ -42,13 +43,13 @@ const char* ParseType(const std::string& type_string, Type& type);
 const char* ModelString(Model model, ModelTraining training);
 const char* StringFromType(Type type);

 // Wraps the given prompt using the expected control tokens for IT models.
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);

 // ----------------------------------------------------------------------------
 //

 // Returns the scale value to use for the embedding (basically sqrt model_dim).
 float EmbeddingScaling(size_t model_dim);

 // Returns the scale value to use for the query in the attention computation.
 float ChooseQueryScale(const ModelConfig& config);

 }  // namespace gcpp

@@ -40,7 +40,7 @@ static ModelConfig ConfigGemma2_27B() {
   config.model_name = "Gemma2_27B";
   config.model = Model::GEMMA2_27B;
   config.model_dim = 4608;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 16 * 4608 / 2,  // = 36864

@@ -61,7 +61,7 @@ static ModelConfig ConfigGemma2_9B() {
   config.model_name = "Gemma2_9B";
   config.model = Model::GEMMA2_9B;
   config.model_dim = 3584;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 3584 / 2,  // = 14336

@@ -82,7 +82,7 @@ static ModelConfig ConfigGemma2_2B() {
   config.model_name = "Gemma2_2B";
   config.model = Model::GEMMA2_2B;
   config.model_dim = 2304;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 2304 / 2,  // = 9216

@@ -103,8 +103,8 @@ static ModelConfig ConfigGemma7B() {
   config.model_name = "Gemma7B";
   config.model = Model::GEMMA_7B;
   config.model_dim = 3072;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 3072 / 2,  // = 24576

@@ -115,7 +115,7 @@ static ModelConfig ConfigGemma7B() {
   config.layer_configs = {28, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
   config.query_scale = QueryScaleType::SqrtKeySize;
-  config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
   return config;
 }

@@ -124,8 +124,8 @@ static ModelConfig ConfigGemma2B() {
   config.model_name = "Gemma2B";
   config.model = Model::GEMMA_2B;
   config.model_dim = 2048;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 2048 / 2,  // = 16384

@@ -135,7 +135,7 @@ static ModelConfig ConfigGemma2B() {
   };
   config.layer_configs = {18, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
-  config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
   return config;
 }

@@ -169,7 +169,7 @@ static ModelConfig ConfigGriffin2B() {
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   config.model_dim = 2560;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 2048;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,

@@ -204,22 +204,34 @@ static ModelConfig ConfigPaliGemma_224() {
   config.model = Model::PALIGEMMA_224;
   config.vit_model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
-  config.vit_seq_len = 16 * 16;
+  config.image_size = 224;
+  config.patch_width = 14;
+  const size_t num_patches = config.image_size / config.patch_width;
+  config.vit_seq_len = num_patches * num_patches;
   LayerConfig layer_config = {
       .model_dim = config.vit_model_dim,
       .ff_hidden_dim = 4304,
       .heads = 16,
       .kv_heads = 16,
       .qkv_dim = 72,
       .ff_biases = true,
       .type = LayerAttentionType::kVit,
-      .patch_width = 14,
-      .image_size = 224,
   };
   config.vit_layer_configs = {27, layer_config};
   config.num_vit_scales = 4 * config.vit_layer_configs.size();
   return config;
 }

+ModelConfig VitConfig(const ModelConfig& config) {
+  ModelConfig vit_config = ConfigNoSSM();
+  vit_config.model_dim = config.vit_model_dim;
+  vit_config.seq_len = config.vit_seq_len;
+  vit_config.layer_configs = config.vit_layer_configs;
+  // The Vit part does not have a vocabulary, the image patches are embedded.
+  vit_config.vocab_size = 0;
+  return vit_config;
+}
+
 ModelConfig ConfigFromModel(Model model) {
   switch (model) {
     case Model::GEMMA_2B:

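As a quick sanity check of the arithmetic introduced above (a standalone sketch, not part of the commit), the new computation reproduces the previously hard-coded 16 * 16:

  #include <cstddef>

  int main() {
    // Values from ConfigPaliGemma_224() above.
    constexpr std::size_t image_size = 224;
    constexpr std::size_t patch_width = 14;
    constexpr std::size_t num_patches = image_size / patch_width;   // 16 patches per side
    constexpr std::size_t vit_seq_len = num_patches * num_patches;  // 256 image tokens
    static_assert(vit_seq_len == 16 * 16, "matches the old hard-coded vit_seq_len");
    return 0;
  }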
@@ -131,9 +131,6 @@ struct LayerConfig {
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
   PostQKType post_qk = PostQKType::Rope;
-  // Dimensions related to image processing.
-  int patch_width = 14;
-  int image_size = 224;
 };

 struct ModelConfig {

@@ -185,11 +182,17 @@ struct ModelConfig {
   std::unordered_set<std::string> scale_names;
   int norm_num_groups = 1;
   int model_family_version = 1;
+  // Dimensions related to image processing.
+  int patch_width = 14;
+  int image_size = 224;
 };

 // Returns the config for the given model.
 ModelConfig ConfigFromModel(Model model);

+// Returns the sub-config for the ViT model of the PaliGemma model.
+ModelConfig VitConfig(const ModelConfig& config);
+
 }  // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

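A minimal usage sketch for the new declarations (assumes gemma/configs.h provides Model, ConfigFromModel and VitConfig as shown in the diff, plus the PaliGemma-224 values from configs.cc above; the Example() wrapper is hypothetical and not part of the commit):

  #include "gemma/configs.h"

  void Example() {
    const gcpp::ModelConfig pali = gcpp::ConfigFromModel(gcpp::Model::PALIGEMMA_224);
    const gcpp::ModelConfig vit = gcpp::VitConfig(pali);
    // From the config above: vit.model_dim == 1152, vit.seq_len == 256
    // (16 x 16 patches), and vit.vocab_size == 0 because image patches
    // are embedded directly instead of looked up in a vocabulary.
  }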
@@ -615,43 +615,39 @@ class VitAttention {
   }

   HWY_NOINLINE void DotSoftmaxWeightedSum() {
-    const float query_scale =
-        1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-    // A "head group" in the context of GQA refers to a collection of query
-    // heads that share the same key and value heads.
-    HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
-                 "Vit expects MHA");
-
     // Compute Q.K, softmax, and weighted V.
-    pool_.Run(
-        0, layer_config_.heads * num_tokens_,
+    pool_.Run(0, layer_config_.heads * num_tokens_,
               [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
                 const size_t head = task % layer_config_.heads;
                 const size_t token = task / layer_config_.heads;
                 // Compute Q.K scores, which are "logits" stored in head_att.
                 float* HWY_RESTRICT q =
-                    activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim;
-                MulByConst(query_scale, q, layer_config_.qkv_dim);
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
                 float* HWY_RESTRICT head_att =
                     activations_.att.Batch(token) + head * activations_.seq_len;
-                for (size_t i = 0; i < activations_.seq_len; ++i) {
-                  float* HWY_RESTRICT k = activations_.q.Batch(i) +
-                                          head * 3 * layer_config_.qkv_dim +
-                                          layer_config_.qkv_dim;
-                  head_att[i] = Dot(q, k, layer_config_.qkv_dim);  // score = q.k
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
                 }
                 // SoftMax yields "probabilities" in head_att.
-                Softmax(head_att, activations_.seq_len);
+                Softmax(head_att, seq_len);
                 // Compute weighted sum of v into att_out.
                 float* HWY_RESTRICT att_out =
-                    activations_.att_out.Batch(token) + head * layer_config_.qkv_dim;
-                hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
-                for (size_t i = 0; i < activations_.seq_len; ++i) {
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
                   float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                          head * 3 * layer_config_.qkv_dim +
-                                          2 * layer_config_.qkv_dim;
-                  MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
                 }
               });
   }

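In equation form, each (token, head) task above is single-head scaled dot-product attention over the ViT sequence (this merely restates the code; here $d$ denotes qkv_dim and $k_i$, $v_i$ are the key and value slices of token $i$ stored in activations_.q):

$$\mathrm{head\_att}_i = \frac{q \cdot k_i}{\sqrt{d}}, \qquad \alpha = \mathrm{softmax}(\mathrm{head\_att}), \qquad \mathrm{att\_out} = \sum_{i=1}^{\mathrm{seq\_len}} \alpha_i \, v_i$$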
@@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
                 layer_weights->vit.layer_norm_0_scale.data_scale1(),
                 layer_weights->vit.layer_norm_0_bias.data_scale1(),
                 activations.pre_att_rms_out.All(), model_dim);

   // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
   VitAttention<T>(num_tokens, layer, activations, layer_weights)();

@@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs<T>& weights,
                                     Activations& activations) {
   const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width =
-      weights.weights_config.vit_layer_configs[0].patch_width;
+  const size_t patch_width = weights.weights_config.patch_width;
   const size_t seq_len = weights.weights_config.vit_seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==

@@ -1483,18 +1479,17 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const Image& image, ImageTokens& image_tokens,
                           PerClusterPools& pools) {
   if (model.Config().vit_layer_configs.empty()) {
-    return;
-  } else {
-    Activations prefill_activations(model.Config());
+    HWY_ABORT("Model does not support generating image tokens.");
+  }
   RuntimeConfig prefill_runtime_config = runtime_config;
-  prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
-  prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
-                               pools);
+  ModelConfig vit_config = VitConfig(model.Config());
+  prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
+  Activations prefill_activations(vit_config);
+  prefill_activations.Allocate(vit_config.seq_len, pools);
+  // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
              image_tokens, prefill_activations);
- }
 }

 }  // namespace HWY_NAMESPACE

@@ -349,7 +349,9 @@ struct ModelWeightsPtrs {
         vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
         vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
         vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3,
-                                 config.vit_model_dim),
+        vit_img_embedding_kernel(
+            "img_emb_kernel",
+            config.patch_width * config.patch_width * 3,
+            config.vit_model_dim),
         vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
         vit_img_head_bias("img_head_bias", 1, config.model_dim),