Fix PaliGemma's GenerateImageTokensT().

Move image related config values from LayerConfig to ModelConfig.
Minor changes: Add a few comments, remove gcpp:: qualification where it wasn't needed in a few places, define local constants in VitAttention.DotSoftmaxWeightedSum()

PiperOrigin-RevId: 687210519
This commit is contained in:
Daniel Keysers 2024-10-18 01:33:40 -07:00 committed by Copybara-Service
parent 0d68555f87
commit c6384574db
5 changed files with 83 additions and 70 deletions

View File

@ -20,12 +20,13 @@
#include <string>
#include "compression/shared.h" // ModelTraining
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo
namespace gcpp {
// TODO(janwas): merge with functions below.
// Struct to bundle model information.
struct ModelInfo {
Model model;
ModelTraining training;
@ -42,13 +43,13 @@ const char* ParseType(const std::string& type_string, Type& type);
const char* ModelString(Model model, ModelTraining training);
const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// ----------------------------------------------------------------------------
//
// Returns the scale value to use for the embedding (basically sqrt model_dim).
float EmbeddingScaling(size_t model_dim);
// Returns the scale value to use for the query in the attention computation.
float ChooseQueryScale(const ModelConfig& config);
} // namespace gcpp

View File

@ -40,7 +40,7 @@ static ModelConfig ConfigGemma2_27B() {
config.model_name = "Gemma2_27B";
config.model = Model::GEMMA2_27B;
config.model_dim = 4608;
config.vocab_size = gcpp::kVocabSize;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 4608 / 2, // = 36864
@ -61,7 +61,7 @@ static ModelConfig ConfigGemma2_9B() {
config.model_name = "Gemma2_9B";
config.model = Model::GEMMA2_9B;
config.model_dim = 3584;
config.vocab_size = gcpp::kVocabSize;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 3584 / 2, // = 14336
@ -82,7 +82,7 @@ static ModelConfig ConfigGemma2_2B() {
config.model_name = "Gemma2_2B";
config.model = Model::GEMMA2_2B;
config.model_dim = 2304;
config.vocab_size = gcpp::kVocabSize;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 2304 / 2, // = 9216
@ -103,8 +103,8 @@ static ModelConfig ConfigGemma7B() {
config.model_name = "Gemma7B";
config.model = Model::GEMMA_7B;
config.model_dim = 3072;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 3072 / 2, // = 24576
@ -115,7 +115,7 @@ static ModelConfig ConfigGemma7B() {
config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
return config;
}
@ -124,8 +124,8 @@ static ModelConfig ConfigGemma2B() {
config.model_name = "Gemma2B";
config.model = Model::GEMMA_2B;
config.model_dim = 2048;
config.vocab_size = gcpp::kVocabSize;
config.seq_len = gcpp::kSeqLen;
config.vocab_size = kVocabSize;
config.seq_len = kSeqLen;
LayerConfig layer_config = {
.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 2048 / 2, // = 16384
@ -135,7 +135,7 @@ static ModelConfig ConfigGemma2B() {
};
config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
return config;
}
@ -169,7 +169,7 @@ static ModelConfig ConfigGriffin2B() {
// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
config.model_dim = 2560;
config.vocab_size = gcpp::kVocabSize;
config.vocab_size = kVocabSize;
config.seq_len = 2048;
LayerConfig layer_config = {
.model_dim = config.model_dim,
@ -204,22 +204,34 @@ static ModelConfig ConfigPaliGemma_224() {
config.model = Model::PALIGEMMA_224;
config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152
config.vit_seq_len = 16 * 16;
config.image_size = 224;
config.patch_width = 14;
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig layer_config = {
.model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304,
.heads = 16,
.kv_heads = 16,
.qkv_dim = 72,
.ff_biases = true,
.type = LayerAttentionType::kVit,
.patch_width = 14,
.image_size = 224,
};
config.vit_layer_configs = {27, layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size();
return config;
}
ModelConfig VitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_model_dim;
vit_config.seq_len = config.vit_seq_len;
vit_config.layer_configs = config.vit_layer_configs;
// The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0;
return vit_config;
}
ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:

View File

@ -131,9 +131,6 @@ struct LayerConfig {
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
};
struct ModelConfig {
@ -185,11 +182,17 @@ struct ModelConfig {
std::unordered_set<std::string> scale_names;
int norm_num_groups = 1;
int model_family_version = 1;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
};
// Returns the config for the given model.
ModelConfig ConfigFromModel(Model model);
// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

View File

@ -615,45 +615,41 @@ class VitAttention {
}
HWY_NOINLINE void DotSoftmaxWeightedSum() {
const float query_scale =
1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim));
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
"Vit expects MHA");
// Compute Q.K, softmax, and weighted V.
pool_.Run(
0, layer_config_.heads * num_tokens_,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim;
MulByConst(query_scale, q, layer_config_.qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * activations_.seq_len;
for (size_t i = 0; i < activations_.seq_len; ++i) {
float* HWY_RESTRICT k = activations_.q.Batch(i) +
head * 3 * layer_config_.qkv_dim +
layer_config_.qkv_dim;
head_att[i] = Dot(q, k, layer_config_.qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, activations_.seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * layer_config_.qkv_dim;
hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < activations_.seq_len; ++i) {
float* HWY_RESTRICT v = activations_.q.Batch(i) +
head * 3 * layer_config_.qkv_dim +
2 * layer_config_.qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
}
});
pool_.Run(0, layer_config_.heads * num_tokens_,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * activations_.seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.q.Batch(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
}
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
layer_weights->vit.layer_norm_0_scale.data_scale1(),
layer_weights->vit.layer_norm_0_bias.data_scale1(),
activations.pre_att_rms_out.All(), model_dim);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention<T>(num_tokens, layer, activations, layer_weights)();
@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights,
Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim;
const size_t patch_width =
weights.weights_config.vit_layer_configs[0].patch_width;
const size_t patch_width = weights.weights_config.patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
@ -1483,17 +1479,16 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const Image& image, ImageTokens& image_tokens,
PerClusterPools& pools) {
if (model.Config().vit_layer_configs.empty()) {
return;
} else {
Activations prefill_activations(model.Config());
RuntimeConfig prefill_runtime_config = runtime_config;
prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
pools);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
image_tokens, prefill_activations);
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
image_tokens, prefill_activations);
}
} // namespace HWY_NAMESPACE

View File

@ -349,8 +349,10 @@ struct ModelWeightsPtrs {
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3,
config.vit_model_dim),
vit_img_embedding_kernel(
"img_emb_kernel",
config.patch_width * config.patch_width * 3,
config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.vit_model_dim,