diff --git a/gemma/common.h b/gemma/common.h
index e933e8d..6b5539a 100644
--- a/gemma/common.h
+++ b/gemma/common.h
@@ -20,12 +20,13 @@
 #include <string>
 
+#include "compression/shared.h"  // ModelTraining
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/base.h"  // ConvertScalarTo
 
 namespace gcpp {
 
-// TODO(janwas): merge with functions below.
+// Struct to bundle model information.
 struct ModelInfo {
   Model model;
   ModelTraining training;
@@ -42,13 +43,13 @@ const char* ParseType(const std::string& type_string, Type& type);
 const char* ModelString(Model model, ModelTraining training);
 const char* StringFromType(Type type);
 
+// Wraps the given prompt using the expected control tokens for IT models.
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
 
-// ----------------------------------------------------------------------------
-//
-
+// Returns the scale value to use for the embedding (basically sqrt model_dim).
 float EmbeddingScaling(size_t model_dim);
 
+// Returns the scale value to use for the query in the attention computation.
 float ChooseQueryScale(const ModelConfig& config);
 
 }  // namespace gcpp
diff --git a/gemma/configs.cc b/gemma/configs.cc
index bc4eee6..f72fe11 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -40,7 +40,7 @@ static ModelConfig ConfigGemma2_27B() {
   config.model_name = "Gemma2_27B";
   config.model = Model::GEMMA2_27B;
   config.model_dim = 4608;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 16 * 4608 / 2,  // = 36864
@@ -61,7 +61,7 @@ static ModelConfig ConfigGemma2_9B() {
   config.model_name = "Gemma2_9B";
   config.model = Model::GEMMA2_9B;
   config.model_dim = 3584;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 3584 / 2,  // = 14336
@@ -82,7 +82,7 @@ static ModelConfig ConfigGemma2_2B() {
   config.model_name = "Gemma2_2B";
   config.model = Model::GEMMA2_2B;
   config.model_dim = 2304;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 8192;
   LayerConfig layer_config = {.model_dim = config.model_dim,
                               .ff_hidden_dim = 8 * 2304 / 2,  // = 9216
@@ -103,8 +103,8 @@ static ModelConfig ConfigGemma7B() {
   config.model_name = "Gemma7B";
   config.model = Model::GEMMA_7B;
   config.model_dim = 3072;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 3072 / 2,  // = 24576
@@ -115,7 +115,7 @@
   config.layer_configs = {28, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
   config.query_scale = QueryScaleType::SqrtKeySize;
-  config.attention_window_sizes = FixedAttentionWindowSizes<28>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<28>(kSeqLen);
   return config;
 }
 
@@ -124,8 +124,8 @@ static ModelConfig ConfigGemma2B() {
   config.model_name = "Gemma2B";
   config.model = Model::GEMMA_2B;
   config.model_dim = 2048;
-  config.vocab_size = gcpp::kVocabSize;
-  config.seq_len = gcpp::kSeqLen;
+  config.vocab_size = kVocabSize;
+  config.seq_len = kSeqLen;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
       .ff_hidden_dim = 16 * 2048 / 2,  // = 16384
@@ -135,7 +135,7 @@
   };
   config.layer_configs = {18, layer_config};
   config.num_tensor_scales = 4 * config.layer_configs.size();
-  config.attention_window_sizes = FixedAttentionWindowSizes<18>(gcpp::kSeqLen);
+  config.attention_window_sizes = FixedAttentionWindowSizes<18>(kSeqLen);
   return config;
 }
 
@@ -169,7 +169,7 @@ static ModelConfig ConfigGriffin2B() {
   // Griffin uses local attention, so kSeqLen is actually the local attention
   // window.
   config.model_dim = 2560;
-  config.vocab_size = gcpp::kVocabSize;
+  config.vocab_size = kVocabSize;
   config.seq_len = 2048;
   LayerConfig layer_config = {
       .model_dim = config.model_dim,
@@ -204,22 +204,34 @@ static ModelConfig ConfigPaliGemma_224() {
   config.model = Model::PALIGEMMA_224;
   config.vit_model_dim = 1152;
   config.vocab_size = 256000 + 1024 + 128;  // = 257152
-  config.vit_seq_len = 16 * 16;
+  config.image_size = 224;
+  config.patch_width = 14;
+  const size_t num_patches = config.image_size / config.patch_width;
+  config.vit_seq_len = num_patches * num_patches;
   LayerConfig layer_config = {
       .model_dim = config.vit_model_dim,
       .ff_hidden_dim = 4304,
       .heads = 16,
       .kv_heads = 16,
       .qkv_dim = 72,
+      .ff_biases = true,
       .type = LayerAttentionType::kVit,
-      .patch_width = 14,
-      .image_size = 224,
   };
   config.vit_layer_configs = {27, layer_config};
   config.num_vit_scales = 4 * config.vit_layer_configs.size();
   return config;
 }
 
+ModelConfig VitConfig(const ModelConfig& config) {
+  ModelConfig vit_config = ConfigNoSSM();
+  vit_config.model_dim = config.vit_model_dim;
+  vit_config.seq_len = config.vit_seq_len;
+  vit_config.layer_configs = config.vit_layer_configs;
+  // The Vit part does not have a vocabulary, the image patches are embedded.
+  vit_config.vocab_size = 0;
+  return vit_config;
+}
+
 ModelConfig ConfigFromModel(Model model) {
   switch (model) {
     case Model::GEMMA_2B:
diff --git a/gemma/configs.h b/gemma/configs.h
index ac82ab4..f7c6ac2 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -131,9 +131,6 @@ struct LayerConfig {
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
   PostQKType post_qk = PostQKType::Rope;
-  // Dimensions related to image processing.
-  int patch_width = 14;
-  int image_size = 224;
 };
 
 struct ModelConfig {
@@ -185,11 +182,17 @@ struct ModelConfig {
   std::unordered_set<std::string> scale_names;
   int norm_num_groups = 1;
   int model_family_version = 1;
+  // Dimensions related to image processing.
+  int patch_width = 14;
+  int image_size = 224;
 };
 
 // Returns the config for the given model.
 ModelConfig ConfigFromModel(Model model);
 
+// Returns the sub-config for the ViT model of the PaliGemma model.
+ModelConfig VitConfig(const ModelConfig& config);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ce72a04..b34916e 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -615,45 +615,41 @@ class VitAttention {
   }
 
   HWY_NOINLINE void DotSoftmaxWeightedSum() {
-    const float query_scale =
-        1.0f / sqrtf(static_cast<float>(layer_config_.qkv_dim));
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
-    // A "head group" in the context of GQA refers to a collection of query
-    // heads that share the same key and value heads.
-    HWY_ASSERT_M(layer_config_.heads == layer_config_.kv_heads,
-                 "Vit expects MHA");
     // Compute Q.K, softmax, and weighted V.
-    pool_.Run(
-        0, layer_config_.heads * num_tokens_,
-        [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
-          const size_t head = task % layer_config_.heads;
-          const size_t token = task / layer_config_.heads;
-          // Compute Q.K scores, which are "logits" stored in head_att.
-          float* HWY_RESTRICT q =
-              activations_.q.Batch(token) + head * 3 * layer_config_.qkv_dim;
-          MulByConst(query_scale, q, layer_config_.qkv_dim);
-          float* HWY_RESTRICT head_att =
-              activations_.att.Batch(token) + head * activations_.seq_len;
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT k = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    layer_config_.qkv_dim;
-            head_att[i] = Dot(q, k, layer_config_.qkv_dim);  // score = q.k
-          }
-          // SoftMax yields "probabilities" in head_att.
-          Softmax(head_att, activations_.seq_len);
-          // Compute weighted sum of v into att_out.
-          float* HWY_RESTRICT att_out =
-              activations_.att_out.Batch(token) + head * layer_config_.qkv_dim;
-          hwy::ZeroBytes(att_out, layer_config_.qkv_dim * sizeof(*att_out));
-          for (size_t i = 0; i < activations_.seq_len; ++i) {
-            float* HWY_RESTRICT v = activations_.q.Batch(i) +
-                                    head * 3 * layer_config_.qkv_dim +
-                                    2 * layer_config_.qkv_dim;
-            MulByConstAndAdd(head_att[i], v, att_out, layer_config_.qkv_dim);
-          }
-        });
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
   }
 
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
@@ -965,6 +961,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
                    layer_weights->vit.layer_norm_0_scale.data_scale1(),
                    layer_weights->vit.layer_norm_0_bias.data_scale1(),
                    activations.pre_att_rms_out.All(), model_dim);
+  // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
   // y ~ att_sums
   VitAttention(num_tokens, layer, activations, layer_weights)();
 
@@ -1104,8 +1101,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                     const ModelWeightsPtrs& weights,
                                     Activations& activations) {
   const size_t model_dim = weights.weights_config.vit_model_dim;
-  const size_t patch_width =
-      weights.weights_config.vit_layer_configs[0].patch_width;
+  const size_t patch_width = weights.weights_config.patch_width;
   const size_t seq_len = weights.weights_config.vit_seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
@@ -1483,17 +1479,16 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
                           const Image& image, ImageTokens& image_tokens,
                           PerClusterPools& pools) {
   if (model.Config().vit_layer_configs.empty()) {
-    return;
-  } else {
-    Activations prefill_activations(model.Config());
-    RuntimeConfig prefill_runtime_config = runtime_config;
-    prefill_runtime_config.prefill_tbatch_size = model.Config().vit_seq_len;
-    prefill_activations.Allocate(prefill_runtime_config.prefill_tbatch_size,
-                                 pools);
-    // Weights are for the full PaliGemma model, not just the ViT part.
-    PrefillVit(*model.GetWeightsOfType(), prefill_runtime_config, image,
-               image_tokens, prefill_activations);
+    HWY_ABORT("Model does not support generating image tokens.");
   }
+  RuntimeConfig prefill_runtime_config = runtime_config;
+  ModelConfig vit_config = VitConfig(model.Config());
+  prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
+  Activations prefill_activations(vit_config);
+  prefill_activations.Allocate(vit_config.seq_len, pools);
+  // Weights are for the full PaliGemma model, not just the ViT part.
+  PrefillVit(*model.GetWeightsOfType(), prefill_runtime_config, image,
+             image_tokens, prefill_activations);
 }
 
 }  // namespace HWY_NAMESPACE
diff --git a/gemma/weights.h b/gemma/weights.h
index 65a5965..60e9d13 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -349,8 +349,10 @@ struct ModelWeightsPtrs {
         vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim),
         vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim),
         vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim),
-        vit_img_embedding_kernel("img_emb_kernel", 14 * 14 * 3,
-                                 config.vit_model_dim),
+        vit_img_embedding_kernel(
+            "img_emb_kernel",
+            config.patch_width * config.patch_width * 3,
+            config.vit_model_dim),
         vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim),
         vit_img_head_bias("img_head_bias", 1, config.model_dim),
         vit_img_head_kernel("img_head_kernel", config.vit_model_dim,