Fix PaliGemma models.

Make the GEMMA_VLM-only ViT steps (the MatMul-based attention path, 4x4 average pooling of the image tokens, and the soft-embedding norm) conditional on the prompt wrapping, and reset the interactive-mode position after each turn for PaliGemma.

PiperOrigin-RevId: 736483021
commit 1b1b63d560, parent 0ff6b3123a
Phil Culliton, 2025-03-13 06:27:52 -07:00, committed by Copybara-Service
4 changed files with 59 additions and 8 deletions

@@ -291,6 +291,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
   vit_config.seq_len = config.vit_config.seq_len;
   vit_config.layer_configs = config.vit_config.layer_configs;
   vit_config.pool_dim = config.vit_config.pool_dim;
+  vit_config.wrapping = config.wrapping;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;
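For orientation: the one-line addition in GetVitConfig threads the prompt-wrapping mode into the ViT sub-config, so the vision code in the later hunks can branch on it instead of reaching back into the full model config. A minimal sketch of that idea with hypothetical standalone types (only the PromptWrapping::GEMMA_VLM and PromptWrapping::PALIGEMMA names come from this commit; everything else is illustrative):

enum class PromptWrapping { GEMMA_IT, GEMMA_VLM, PALIGEMMA };  // illustrative subset

struct VitConfigSketch {
  PromptWrapping wrapping;  // copied from the top-level model config, as in the diff
};

// The Gemma-3-style VLM path pools and re-normalizes the image tokens;
// PaliGemma keeps its original path. Downstream ViT code only needs this bit.
inline bool UseVlmVitPath(const VitConfigSketch& cfg) {
  return cfg.wrapping == PromptWrapping::GEMMA_VLM;
}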

@@ -20,6 +20,7 @@
#include <stdio.h>
#include <algorithm> // std::min
#include <cstdio>
#include <vector>
#include "compression/compress.h"
@@ -610,7 +611,7 @@ class VitAttention {
   }
   // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
@@ -669,6 +670,44 @@ class VitAttention {
     }
   }
+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
+  }
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
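Aside, before the next hunk: the re-added DotSoftmaxWeightedSum above is the non-MatMul attention path. For each (token, head) pair it scales q, takes dot products against every key to get logits, applies a softmax, and accumulates the value rows. The same math as a plain scalar reference, assuming the packed per-token QKV layout used above (head h's q/k/v start at h*3*qkv_dim + 0, + qkv_dim, and + 2*qkv_dim within a row); this is an unvectorized sketch with invented names, not the Highway code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for one (token, head) pair. `qkv` holds seq_len rows of
// `row_stride` floats; `out` receives qkv_dim floats. Assumes seq_len > 0.
void AttendOneHead(const float* qkv, size_t seq_len, size_t row_stride,
                   size_t qkv_dim, size_t head, size_t token, float* out) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  const float* q = qkv + token * row_stride + head * 3 * qkv_dim;
  std::vector<float> att(seq_len);
  // Scaled dot-product logits q.k (the diff scales q in place; same result).
  for (size_t i = 0; i < seq_len; ++i) {
    const float* k = qkv + i * row_stride + head * 3 * qkv_dim + qkv_dim;
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[d];
    att[i] = scale * dot;
  }
  // Numerically stable softmax over the logits.
  const float max_logit = *std::max_element(att.begin(), att.end());
  float sum = 0.0f;
  for (float& a : att) { a = std::exp(a - max_logit); sum += a; }
  for (float& a : att) a /= sum;
  // Weighted sum of value rows.
  std::fill(out, out + qkv_dim, 0.0f);
  for (size_t i = 0; i < seq_len; ++i) {
    const float* v = qkv + i * row_stride + head * 3 * qkv_dim + 2 * qkv_dim;
    for (size_t d = 0; d < qkv_dim; ++d) out[d] += att[i] * v[d];
  }
}

The Highway version in the hunk parallelizes over (head, token) tasks via pool_.Run and replaces the inner loops with vectorized Dot and MulByConstAndAdd calls.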
@@ -695,7 +734,11 @@ class VitAttention {
   HWY_INLINE void operator()() {
     ComputeQKV();
-    DotSoftmaxWeightedSum();
+    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+      DotSoftmaxWeightedSumMatrix();
+    } else {
+      DotSoftmaxWeightedSum();
+    }
     SumHeads();
   }
@@ -1177,11 +1220,13 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                    weights.vit_encoder_norm_bias.data_scale1(),
                    activations.x.All(), vit_model_dim);
-  activations.x = AvgPool4x4(activations.x);
-  // Apply soft embedding norm before input projection.
-  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                 vit_model_dim);
+  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    activations.x = AvgPool4x4(activations.x);
+    // Apply soft embedding norm before input projection.
+    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+                   vit_model_dim);
+  }
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
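The PrefillVit change above gates both the 4x4 average pooling of the ViT output and the soft-embedding RMSNorm on GEMMA_VLM, so PaliGemma keeps the full, un-pooled patch sequence. As a rough illustration of the pooling step only (assuming AvgPool4x4 averages each 4x4 block of a square patch grid, e.g. 64x64 patch embeddings down to 16x16 image tokens; only the function name comes from the diff, the rest is a hypothetical sketch):

#include <cstddef>
#include <vector>

// Average each 4x4 block of a (grid x grid) token grid; each row is `dim` floats.
// Returns (grid/4)*(grid/4) pooled rows. Illustrative reference, not gemma.cpp's AvgPool4x4.
std::vector<float> AvgPool4x4Sketch(const std::vector<float>& x, size_t grid, size_t dim) {
  const size_t out_grid = grid / 4;
  std::vector<float> out(out_grid * out_grid * dim, 0.0f);
  for (size_t oy = 0; oy < out_grid; ++oy) {
    for (size_t ox = 0; ox < out_grid; ++ox) {
      float* dst = out.data() + (oy * out_grid + ox) * dim;
      for (size_t dy = 0; dy < 4; ++dy) {
        for (size_t dx = 0; dx < 4; ++dx) {
          const float* src = x.data() + ((oy * 4 + dy) * grid + (ox * 4 + dx)) * dim;
          for (size_t d = 0; d < dim; ++d) dst[d] += src[d] / 16.0f;
        }
      }
    }
  }
  return out;
}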

@@ -217,8 +217,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                   timing_info);
     std::cout << "\n\n";
-    // Prepare for the next turn.
+    // Prepare for the next turn. Works only for PaliGemma.
+    if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
       abs_pos = 0;  // Start a new turn at position 0.
       InitGenerator(args, gen);
     } else {
       // The last token was either EOS, then it should be ignored because it is
       // never part of the dialog, see Table 5 in the Gemma-2 paper:
       // https://arxiv.org/pdf/2408.00118
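The ReplGemma hunk above makes the interactive loop start the next turn at position 0 not only when multiturn mode is off but also for PaliGemma, whose prompt wrapping does not carry a dialog across turns. The rule in isolation, as a hypothetical helper (enum reduced to the names used in this commit):

enum class PromptWrapping { GEMMA_VLM, PALIGEMMA };  // illustrative subset

// True when the interactive loop should reset abs_pos and re-seed the generator.
inline bool ShouldResetTurn(bool multiturn, PromptWrapping wrapping) {
  return !multiturn || wrapping == PromptWrapping::PALIGEMMA;
}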

@@ -500,7 +500,9 @@ struct ModelWeightsPtrs {
     GEMMA_CALL_FUNC(vit_img_pos_embedding);
     GEMMA_CALL_FUNC(vit_img_head_bias);
     GEMMA_CALL_FUNC(vit_img_head_kernel);
-    GEMMA_CALL_FUNC(mm_embed_norm);
+    if (ptrs[0]->weights_config.wrapping == PromptWrapping::GEMMA_VLM)
+      GEMMA_CALL_FUNC(mm_embed_norm);
   }
   for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {