Fix PaliGemma's GenerateImageTokensT().

Move image related config values from LayerConfig to ModelConfig.
Minor changes: add a few comments, remove unnecessary gcpp:: qualifications in a few places, and define local constants in VitAttention::DotSoftmaxWeightedSum().

PiperOrigin-RevId: 687210519
This commit is contained in:
Daniel Keysers 2024-10-18 01:33:40 -07:00 committed by Copybara-Service
parent 0d68555f87
commit c6384574db
5 changed files with 83 additions and 70 deletions

View File

@ -20,12 +20,13 @@
#include <string> #include <string>
#include "compression/shared.h" // ModelTraining
#include "gemma/configs.h" // IWYU pragma: export #include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo #include "hwy/base.h" // ConvertScalarTo
namespace gcpp { namespace gcpp {
// TODO(janwas): merge with functions below. // Struct to bundle model information.
struct ModelInfo { struct ModelInfo {
Model model; Model model;
ModelTraining training; ModelTraining training;
@ -42,13 +43,13 @@ const char* ParseType(const std::string& type_string, Type& type);
const char* ModelString(Model model, ModelTraining training); const char* ModelString(Model model, ModelTraining training);
const char* StringFromType(Type type); const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// ---------------------------------------------------------------------------- // Returns the scale value to use for the embedding (basically sqrt model_dim).
//
float EmbeddingScaling(size_t model_dim); float EmbeddingScaling(size_t model_dim);
// Returns the scale value to use for the query in the attention computation.
float ChooseQueryScale(const ModelConfig& config); float ChooseQueryScale(const ModelConfig& config);
} // namespace gcpp } // namespace gcpp

View File

@ -40,7 +40,7 @@ static ModelConfig ConfigGemma2_27B() {
config.model_name = "Gemma2_27B"; config.model_name = "Gemma2_27B";
config.model = Model::GEMMA2_27B; config.model = Model::GEMMA2_27B;
config.model_dim = 4608; config.model_dim = 4608;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = 8192; config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim, LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 16 * 4608 / 2, // = 36864 .ff_hidden_dim = 16 * 4608 / 2, // = 36864
@ -61,7 +61,7 @@ static ModelConfig ConfigGemma2_9B() {
config.model_name = "Gemma2_9B"; config.model_name = "Gemma2_9B";
config.model = Model::GEMMA2_9B; config.model = Model::GEMMA2_9B;
config.model_dim = 3584; config.model_dim = 3584;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = 8192; config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim, LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 3584 / 2, // = 14336 .ff_hidden_dim = 8 * 3584 / 2, // = 14336
@ -82,7 +82,7 @@ static ModelConfig ConfigGemma2_2B() {
config.model_name = "Gemma2_2B"; config.model_name = "Gemma2_2B";
config.model = Model::GEMMA2_2B; config.model = Model::GEMMA2_2B;
config.model_dim = 2304; config.model_dim = 2304;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = 8192; config.seq_len = 8192;
LayerConfig layer_config = {.model_dim = config.model_dim, LayerConfig layer_config = {.model_dim = config.model_dim,
.ff_hidden_dim = 8 * 2304 / 2, // = 9216 .ff_hidden_dim = 8 * 2304 / 2, // = 9216
@ -103,8 +103,8 @@ static ModelConfig ConfigGemma7B() {
config.model_name = "Gemma7B"; config.model_name = "Gemma7B";
config.model = Model::GEMMA_7B; config.model = Model::GEMMA_7B;
config.model_dim = 3072; config.model_dim = 3072;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = gcpp::kSeqLen; config.seq_len = kSeqLen;
LayerConfig layer_config = { LayerConfig layer_config = {
.model_dim = config.model_dim, .model_dim = config.model_dim,
.ff_hidden_dim = 16 * 3072 / 2, // = 24576 .ff_hidden_dim = 16 * 3072 / 2, // = 24576
@ -115,7 +115,7 @@ static ModelConfig ConfigGemma7B() {
config.layer_configs = {28, layer_config}; config.layer_configs = {28, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size(); config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize; config.query_scale = QueryScaleType::SqrtKeySize;
config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen); config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
return config; return config;
} }
@ -124,8 +124,8 @@ static ModelConfig ConfigGemma2B() {
config.model_name = "Gemma2B"; config.model_name = "Gemma2B";
config.model = Model::GEMMA_2B; config.model = Model::GEMMA_2B;
config.model_dim = 2048; config.model_dim = 2048;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = gcpp::kSeqLen; config.seq_len = kSeqLen;
LayerConfig layer_config = { LayerConfig layer_config = {
.model_dim = config.model_dim, .model_dim = config.model_dim,
.ff_hidden_dim = 16 * 2048 / 2, // = 16384 .ff_hidden_dim = 16 * 2048 / 2, // = 16384
@ -135,7 +135,7 @@ static ModelConfig ConfigGemma2B() {
}; };
config.layer_configs = {18, layer_config}; config.layer_configs = {18, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size(); config.num_tensor_scales = 4 * config.layer_configs.size();
config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen); config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
return config; return config;
} }
@ -169,7 +169,7 @@ static ModelConfig ConfigGriffin2B() {
// Griffin uses local attention, so kSeqLen is actually the local attention // Griffin uses local attention, so kSeqLen is actually the local attention
// window. // window.
config.model_dim = 2560; config.model_dim = 2560;
config.vocab_size = gcpp::kVocabSize; config.vocab_size = kVocabSize;
config.seq_len = 2048; config.seq_len = 2048;
LayerConfig layer_config = { LayerConfig layer_config = {
.model_dim = config.model_dim, .model_dim = config.model_dim,
@ -204,22 +204,34 @@ static ModelConfig ConfigPaliGemma_224() {
config.model = Model::PALIGEMMA_224; config.model = Model::PALIGEMMA_224;
config.vit_model_dim = 1152; config.vit_model_dim = 1152;
config.vocab_size = 256000 + 1024 + 128; // = 257152 config.vocab_size = 256000 + 1024 + 128; // = 257152
config.vit_seq_len = 16 * 16; config.image_size = 224;
config.patch_width = 14;
const size_t num_patches = config.image_size / config.patch_width;
config.vit_seq_len = num_patches * num_patches;
LayerConfig layer_config = { LayerConfig layer_config = {
.model_dim = config.vit_model_dim, .model_dim = config.vit_model_dim,
.ff_hidden_dim = 4304, .ff_hidden_dim = 4304,
.heads = 16, .heads = 16,
.kv_heads = 16, .kv_heads = 16,
.qkv_dim = 72, .qkv_dim = 72,
.ff_biases = true,
.type = LayerAttentionType::kVit, .type = LayerAttentionType::kVit,
.patch_width = 14,
.image_size = 224,
}; };
config.vit_layer_configs = {27, layer_config}; config.vit_layer_configs = {27, layer_config};
config.num_vit_scales = 4 * config.vit_layer_configs.size(); config.num_vit_scales = 4 * config.vit_layer_configs.size();
return config; return config;
} }
// Builds the config for the ViT (image encoder) part of a PaliGemma model,
// derived from the full model's config.
ModelConfig VitConfig(const ModelConfig& config) {
  ModelConfig out = ConfigNoSSM();
  // The ViT tower runs at its own width, sequence length and layer stack.
  out.model_dim = config.vit_model_dim;
  out.seq_len = config.vit_seq_len;
  out.layer_configs = config.vit_layer_configs;
  // The ViT part has no vocabulary; image patches are embedded directly.
  out.vocab_size = 0;
  return out;
}
ModelConfig ConfigFromModel(Model model) { ModelConfig ConfigFromModel(Model model) {
switch (model) { switch (model) {
case Model::GEMMA_2B: case Model::GEMMA_2B:

View File

@ -131,9 +131,6 @@ struct LayerConfig {
LayerAttentionType type = LayerAttentionType::kGemma; LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu; ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope; PostQKType post_qk = PostQKType::Rope;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
}; };
struct ModelConfig { struct ModelConfig {
@ -185,11 +182,17 @@ struct ModelConfig {
std::unordered_set<std::string> scale_names; std::unordered_set<std::string> scale_names;
int norm_num_groups = 1; int norm_num_groups = 1;
int model_family_version = 1; int model_family_version = 1;
// Dimensions related to image processing.
int patch_width = 14;
int image_size = 224;
}; };
// Returns the config for the given model. // Returns the config for the given model.
ModelConfig ConfigFromModel(Model model); ModelConfig ConfigFromModel(Model model);
// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_

View File

@ -615,45 +615,41 @@ class VitAttention {
} }
HWY_NOINLINE void DotSoftmaxWeightedSum() { HWY_NOINLINE void DotSoftmaxWeightedSum() {
const float query_scale = const size_t qkv_dim = layer_config_.qkv_dim;
1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim)); const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// A "head group" in the context of GQA refers to a collection of query
// heads that share the same key and value heads.
HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
"Vit expects MHA");
// Compute Q.K, softmax, and weighted V. // Compute Q.K, softmax, and weighted V.
pool_.Run( pool_.Run(0, layer_config_.heads * num_tokens_,
0, layer_config_.heads * num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
[&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t head = task % layer_config_.heads;
const size_t head = task % layer_config_.heads; const size_t token = task / layer_config_.heads;
const size_t token = task / layer_config_.heads; // Compute Q.K scores, which are "logits" stored in head_att.
// Compute Q.K scores, which are "logits" stored in head_att. float* HWY_RESTRICT q =
float* HWY_RESTRICT q = activations_.q.Batch(token) + head * 3 * qkv_dim;
activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim; MulByConst(query_scale, q, qkv_dim);
MulByConst(query_scale, q, layer_config_.qkv_dim); float* HWY_RESTRICT head_att =
float* HWY_RESTRICT head_att = activations_.att.Batch(token) + head * activations_.seq_len;
activations_.att.Batch(token) + head * activations_.seq_len; for (size_t i = 0; i < seq_len; ++i) {
for (size_t i = 0; i < activations_.seq_len; ++i) { float* HWY_RESTRICT k =
float* HWY_RESTRICT k = activations_.q.Batch(i) + activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
head * 3 * layer_config_.qkv_dim + head_att[i] = Dot(q, k, qkv_dim); // score = q.k
layer_config_.qkv_dim; }
head_att[i] = Dot(q, k, layer_config_.qkv_dim); // score = q.k // SoftMax yields "probabilities" in head_att.
} Softmax(head_att, seq_len);
// SoftMax yields "probabilities" in head_att. // Compute weighted sum of v into att_out.
Softmax(head_att, activations_.seq_len); float* HWY_RESTRICT att_out =
// Compute weighted sum of v into att_out. activations_.att_out.Batch(token) + head * qkv_dim;
float* HWY_RESTRICT att_out = hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
activations_.att_out.Batch(token) + head * layer_config_.qkv_dim; for (size_t i = 0; i < seq_len; ++i) {
hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out)); float* HWY_RESTRICT v = activations_.q.Batch(i) +
for (size_t i = 0; i < activations_.seq_len; ++i) { head * 3 * qkv_dim + 2 * qkv_dim;
float* HWY_RESTRICT v = activations_.q.Batch(i) + MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
head * 3 * layer_config_.qkv_dim + }
2 * layer_config_.qkv_dim; });
MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
}
});
} }
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
layer_weights->vit.layer_norm_0_scale.data_scale1(), layer_weights->vit.layer_norm_0_scale.data_scale1(),
layer_weights->vit.layer_norm_0_bias.data_scale1(), layer_weights->vit.layer_norm_0_bias.data_scale1(),
activations.pre_att_rms_out.All(), model_dim); activations.pre_att_rms_out.All(), model_dim);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums // y ~ att_sums
VitAttention<T>(num_tokens, layer, activations, layer_weights)(); VitAttention<T>(num_tokens, layer, activations, layer_weights)();
@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights, const ModelWeightsPtrs<T>& weights,
Activations& activations) { Activations& activations) {
const size_t model_dim = weights.weights_config.vit_model_dim; const size_t model_dim = weights.weights_config.vit_model_dim;
const size_t patch_width = const size_t patch_width = weights.weights_config.patch_width;
weights.weights_config.vit_layer_configs[0].patch_width;
const size_t seq_len = weights.weights_config.vit_seq_len; const size_t seq_len = weights.weights_config.vit_seq_len;
const size_t patch_size = patch_width * patch_width * 3; const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() == HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
@ -1483,17 +1479,16 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
const Image& image, ImageTokens& image_tokens, const Image& image, ImageTokens& image_tokens,
PerClusterPools& pools) { PerClusterPools& pools) {
if (model.Config().vit_layer_configs.empty()) { if (model.Config().vit_layer_configs.empty()) {
return; HWY_ABORT("Model does not support generating image tokens.");
} else {
Activations prefill_activations(model.Config());
RuntimeConfig prefill_runtime_config = runtime_config;
prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
pools);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
image_tokens, prefill_activations);
} }
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = VitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, pools);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(*model.GetWeightsOfType<T>(), prefill_runtime_config, image,
image_tokens, prefill_activations);
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE

View File

@ -349,8 +349,10 @@ struct ModelWeightsPtrs {
vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim), vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim), vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim), vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3, vit_img_embedding_kernel(
config.vit_model_dim), "img_emb_kernel",
config.patch_width * config.patch_width * 3,
config.vit_model_dim),
vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim), vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
vit_img_head_bias("img_head_bias", 1, config.model_dim), vit_img_head_bias("img_head_bias", 1, config.model_dim),
vit_img_head_kernel("img_head_kernel", config.vit_model_dim, vit_img_head_kernel("img_head_kernel", config.vit_model_dim,