mirror of https://github.com/google/gemma.cpp.git
parent 0ff6b3123a
commit 1b1b63d560
@@ -291,6 +291,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
   vit_config.seq_len = config.vit_config.seq_len;
   vit_config.layer_configs = config.vit_config.layer_configs;
   vit_config.pool_dim = config.vit_config.pool_dim;
+  vit_config.wrapping = config.wrapping;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;
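The single added line above makes the derived ViT configuration carry the parent model's prompt-wrapping mode, so ViT-side code can branch on GEMMA_VLM versus PALIGEMMA without reaching back into the full ModelConfig. A minimal, self-contained sketch of that pattern (not part of the diff; the *Sketch structs are illustrative stand-ins, not the real gemma.cpp types):

#include <cstddef>

// Sketch only: simplified stand-ins for ModelConfig / VitConfig.
enum class PromptWrapping { PALIGEMMA, GEMMA_VLM };

struct VitConfigSketch {
  size_t seq_len = 0;
  size_t vocab_size = 0;
  PromptWrapping wrapping = PromptWrapping::PALIGEMMA;
};

struct ModelConfigSketch {
  VitConfigSketch vit_config;
  PromptWrapping wrapping = PromptWrapping::PALIGEMMA;
};

// Shape of GetVitConfig after this commit: copy the ViT fields, inherit the
// wrapping mode from the parent config, and zero the vocabulary because image
// patches are embedded directly rather than looked up.
VitConfigSketch GetVitConfigSketch(const ModelConfigSketch& config) {
  VitConfigSketch vit = config.vit_config;
  vit.wrapping = config.wrapping;  // the newly added propagation
  vit.vocab_size = 0;
  return vit;
}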
@@ -20,6 +20,7 @@
 #include <stdio.h>

 #include <algorithm>  // std::min
+#include <cstdio>
 #include <vector>

 #include "compression/compress.h"
@@ -610,7 +611,7 @@ class VitAttention {
   }

   // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
@@ -669,6 +670,44 @@ class VitAttention {
     }
   }

+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
+  }
+
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
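The re-added DotSoftmaxWeightedSum above is the per-token path: for each (head, token) pair it scales Q by 1/sqrt(qkv_dim), takes dot products against every K to form logits, applies softmax, and accumulates a probability-weighted sum of V. A plain scalar sketch of the same math (not part of the diff; no Highway, thread pool, or packed QKV layout, and the array shapes and names are assumptions):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of one attention head over seq_len tokens.
// q, k, v: [seq_len][qkv_dim], flattened row-major. out: same shape as q.
void DotSoftmaxWeightedSumSketch(const std::vector<float>& q,
                                 const std::vector<float>& k,
                                 const std::vector<float>& v,
                                 std::vector<float>& out,
                                 size_t seq_len, size_t qkv_dim) {
  const float query_scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  std::vector<float> att(seq_len);
  for (size_t token = 0; token < seq_len; ++token) {
    const float* qt = &q[token * qkv_dim];
    // Logits: scaled dot products of this query against every key.
    float max_logit = -1e30f;
    for (size_t i = 0; i < seq_len; ++i) {
      const float* ki = &k[i * qkv_dim];
      float dot = 0.0f;
      for (size_t d = 0; d < qkv_dim; ++d) dot += qt[d] * ki[d];
      att[i] = dot * query_scale;  // equivalent to scaling q in place
      max_logit = std::max(max_logit, att[i]);
    }
    // Softmax (subtract max for numerical stability).
    float sum = 0.0f;
    for (size_t i = 0; i < seq_len; ++i) {
      att[i] = std::exp(att[i] - max_logit);
      sum += att[i];
    }
    for (size_t i = 0; i < seq_len; ++i) att[i] /= sum;
    // Weighted sum of values into the output row for this token.
    float* ot = &out[token * qkv_dim];
    for (size_t d = 0; d < qkv_dim; ++d) ot[d] = 0.0f;
    for (size_t i = 0; i < seq_len; ++i) {
      const float* vi = &v[i * qkv_dim];
      for (size_t d = 0; d < qkv_dim; ++d) ot[d] += att[i] * vi[d];
    }
  }
}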
@@ -695,7 +734,11 @@ class VitAttention {

   HWY_INLINE void operator()() {
     ComputeQKV();
-    DotSoftmaxWeightedSum();
+    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+      DotSoftmaxWeightedSumMatrix();
+    } else {
+      DotSoftmaxWeightedSum();
+    }
     SumHeads();
   }

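operator() now dispatches on the wrapping mode: the MatMul-based DotSoftmaxWeightedSumMatrix for GEMMA_VLM, the per-token DotSoftmaxWeightedSum otherwise. A compact sketch of that dispatch shape (not part of the diff; the stand-in class and empty bodies are illustrative only):

// Sketch of the dispatch added to operator(): pick the attention kernel
// based on the prompt-wrapping mode carried by the activations' config.
enum class PromptWrapping { PALIGEMMA, GEMMA_VLM };

struct VitAttentionSketch {
  PromptWrapping wrapping;
  void ComputeQKV() {}
  void DotSoftmaxWeightedSumMatrix() {}  // batched MatMul path (GEMMA_VLM)
  void DotSoftmaxWeightedSum() {}        // per-(head, token) loop path
  void SumHeads() {}

  void operator()() {
    ComputeQKV();
    if (wrapping == PromptWrapping::GEMMA_VLM) {
      DotSoftmaxWeightedSumMatrix();
    } else {
      DotSoftmaxWeightedSum();
    }
    SumHeads();
  }
};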
@@ -1177,11 +1220,13 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
                  weights.vit_encoder_norm_bias.data_scale1(),
                  activations.x.All(), vit_model_dim);

-  activations.x = AvgPool4x4(activations.x);
+  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    activations.x = AvgPool4x4(activations.x);

-  // Apply soft embedding norm before input projection.
-  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                 vit_model_dim);
+    // Apply soft embedding norm before input projection.
+    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+                   vit_model_dim);
+  }

   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
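In PrefillVit, the 4x4 average pooling and the soft-embedding RMSNorm now run only for GEMMA_VLM; PaliGemma keeps the unpooled patch sequence. AvgPool4x4's implementation is not shown in this hunk, but assuming it averages non-overlapping 4x4 blocks of the patch grid (e.g. a 64x64 grid of patch tokens pooled down to 16x16, i.e. 4096 tokens to 256), a scalar sketch of that operation looks like this (not part of the diff; grid layout and names are assumptions, not the real gemma.cpp code):

#include <cstddef>
#include <vector>

// Sketch: average-pool a [grid x grid] token grid by non-overlapping 4x4
// windows, averaging each model_dim-sized token vector. Input is row-major
// [grid * grid][model_dim]; output is [(grid/4) * (grid/4)][model_dim].
std::vector<float> AvgPool4x4Sketch(const std::vector<float>& tokens,
                                    size_t grid, size_t model_dim) {
  const size_t out_grid = grid / 4;
  std::vector<float> pooled(out_grid * out_grid * model_dim, 0.0f);
  for (size_t oy = 0; oy < out_grid; ++oy) {
    for (size_t ox = 0; ox < out_grid; ++ox) {
      float* out = &pooled[(oy * out_grid + ox) * model_dim];
      // Sum the 16 token vectors in this window, then divide by 16.
      for (size_t dy = 0; dy < 4; ++dy) {
        for (size_t dx = 0; dx < 4; ++dx) {
          const size_t y = oy * 4 + dy, x = ox * 4 + dx;
          const float* in = &tokens[(y * grid + x) * model_dim];
          for (size_t d = 0; d < model_dim; ++d) out[d] += in[d];
        }
      }
      for (size_t d = 0; d < model_dim; ++d) out[d] *= 1.0f / 16.0f;
    }
  }
  return pooled;
}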
@@ -217,8 +217,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                    timing_info);
     std::cout << "\n\n";

-    // Prepare for the next turn.
+    // Prepare for the next turn. Works only for PaliGemma.
     if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
       abs_pos = 0;  // Start a new turn at position 0.
       InitGenerator(args, gen);
     } else {
       // The last token was either EOS, then it should be ignored because it is
       // never part of the dialog, see Table 5 in the Gemma-2 paper:
       // https://arxiv.org/pdf/2408.00118
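This hunk (presumably gemma/run.cc, given ReplGemma) only adjusts the comment on the turn-reset branch: when multi-turn is off, or the model uses PaliGemma-style wrapping, the loop starts the next turn from position 0 and re-seeds the generator; otherwise it keeps the position so the KV cache carries the dialog forward. A reduced sketch of that control flow (not part of the diff; the struct and helper names are stand-ins, not the gemma.cpp API):

#include <cstddef>
#include <random>

// Stand-in values; the real PromptWrapping enum has more members.
enum class PromptWrapping { PALIGEMMA, GEMMA_VLM, OTHER };

struct ArgsSketch {
  bool multiturn = false;
};

// Sketch of the per-turn bookkeeping after a reply has been printed.
void PrepareNextTurnSketch(const ArgsSketch& args, PromptWrapping wrapping,
                           size_t& abs_pos, std::mt19937& gen) {
  if (!args.multiturn || wrapping == PromptWrapping::PALIGEMMA) {
    abs_pos = 0;                        // Start a new turn at position 0.
    gen.seed(std::random_device{}());   // Fresh RNG, as InitGenerator does
                                        // (its deterministic-seed option is
                                        // omitted in this sketch).
  } else {
    // Keep abs_pos so the KV cache from previous turns stays valid; the real
    // loop then adjusts the position for the trailing EOS token, but that
    // part of the code is cut off in this hunk.
  }
}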
@@ -500,7 +500,9 @@ struct ModelWeightsPtrs {
     GEMMA_CALL_FUNC(vit_img_pos_embedding);
     GEMMA_CALL_FUNC(vit_img_head_bias);
     GEMMA_CALL_FUNC(vit_img_head_kernel);
-    GEMMA_CALL_FUNC(mm_embed_norm);
+
+    if (ptrs[0]->weights_config.wrapping == PromptWrapping::GEMMA_VLM)
+      GEMMA_CALL_FUNC(mm_embed_norm);
   }

   for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
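GEMMA_CALL_FUNC appears to apply a per-tensor functor to each named weight; guarding mm_embed_norm behind the GEMMA_VLM check means non-VLM checkpoints, which do not contain that tensor, are still iterated without touching it. A reduced sketch of the conditional-tensor-iteration pattern (not part of the diff; the member names come from the hunk, but the functor interface and types are assumptions):

#include <cstdio>
#include <string>

enum class PromptWrapping { PALIGEMMA, GEMMA_VLM };

struct WeightsSketch {
  PromptWrapping wrapping;
  std::string vit_img_head_kernel = "vit_img_head_kernel";
  std::string mm_embed_norm = "mm_embed_norm";

  // Stand-in for the GEMMA_CALL_FUNC loop: apply `func` to each tensor the
  // model actually has.
  template <class Func>
  void ForEachTensor(Func&& func) {
    func(vit_img_head_kernel);
    // Only VLM checkpoints ship the soft-embedding norm tensor, so only
    // visit it when the wrapping mode says it exists.
    if (wrapping == PromptWrapping::GEMMA_VLM) func(mm_embed_norm);
  }
};

int main() {
  WeightsSketch w{PromptWrapping::GEMMA_VLM};
  w.ForEachTensor(
      [](const std::string& name) { std::printf("%s\n", name.c_str()); });
}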