diff --git a/gemma/configs.cc b/gemma/configs.cc
index d058faf..ebec631 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -291,6 +291,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
   vit_config.seq_len = config.vit_config.seq_len;
   vit_config.layer_configs = config.vit_config.layer_configs;
   vit_config.pool_dim = config.vit_config.pool_dim;
+  vit_config.wrapping = config.wrapping;
   // The Vit part does not have a vocabulary, the image patches are embedded.
   vit_config.vocab_size = 0;
   return vit_config;
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index e971fbe..d87be54 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -20,6 +20,7 @@
 #include <...>
 #include <algorithm>  // std::min
+#include <...>
 #include <...>
 
 #include "compression/compress.h"
@@ -610,7 +611,7 @@ class VitAttention {
   }
 
   // TODO(philculliton): transition fully to MatMul.
-  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+  HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
@@ -669,6 +670,44 @@
     }
   }
 
+  HWY_NOINLINE void DotSoftmaxWeightedSum() {
+    const size_t qkv_dim = layer_config_.qkv_dim;
+    const size_t heads = layer_config_.heads;
+    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
+    const size_t seq_len = activations_.seq_len;
+    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
+    PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
+
+    // Compute Q.K, softmax, and weighted V.
+    pool_.Run(0, layer_config_.heads * num_tokens_,
+              [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
+                const size_t head = task % layer_config_.heads;
+                const size_t token = task / layer_config_.heads;
+                // Compute Q.K scores, which are "logits" stored in head_att.
+                float* HWY_RESTRICT q =
+                    activations_.q.Batch(token) + head * 3 * qkv_dim;
+                MulByConst(query_scale, q, qkv_dim);
+                float* HWY_RESTRICT head_att =
+                    activations_.att.Batch(token) + head * activations_.seq_len;
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT k =
+                      activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
+                  head_att[i] = Dot(q, k, qkv_dim);  // score = q.k
+                }
+                // SoftMax yields "probabilities" in head_att.
+                Softmax(head_att, seq_len);
+                // Compute weighted sum of v into att_out.
+                float* HWY_RESTRICT att_out =
+                    activations_.att_out.Batch(token) + head * qkv_dim;
+                hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
+                for (size_t i = 0; i < seq_len; ++i) {
+                  float* HWY_RESTRICT v = activations_.q.Batch(i) +
+                                          head * 3 * qkv_dim + 2 * qkv_dim;
+                  MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
+                }
+              });
+  }
+
   // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
   // head_dim (`qkv_dim`) into output (`att_sums`).
   HWY_NOINLINE void SumHeads() {
@@ -695,7 +734,11 @@ class VitAttention {
 
   HWY_INLINE void operator()() {
     ComputeQKV();
-    DotSoftmaxWeightedSum();
+    if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+      DotSoftmaxWeightedSumMatrix();
+    } else {
+      DotSoftmaxWeightedSum();
+    }
     SumHeads();
   }
 
@@ -1177,11 +1220,13 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights,
                  weights.vit_encoder_norm_bias.data_scale1(), activations.x.All(),
                  vit_model_dim);
 
-  activations.x = AvgPool4x4(activations.x);
+  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
+    activations.x = AvgPool4x4(activations.x);
 
-  // Apply soft embedding norm before input projection.
-  RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
-                 vit_model_dim);
+    // Apply soft embedding norm before input projection.
+    RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+                   vit_model_dim);
+  }
 
   // Apply head embedding into image_tokens of size of the LLM kModelDim.
   MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
diff --git a/gemma/run.cc b/gemma/run.cc
index e4c2d2e..c0b1eb5 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -217,8 +217,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                    timing_info);
     std::cout << "\n\n";
 
-    // Prepare for the next turn.
+    // Prepare for the next turn. Works only for PaliGemma.
     if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
+      abs_pos = 0;  // Start a new turn at position 0.
+      InitGenerator(args, gen);
+    } else {
       // The last token was either EOS, then it should be ignored because it is
       // never part of the dialog, see Table 5 in the Gemma-2 paper:
       // https://arxiv.org/pdf/2408.00118
diff --git a/gemma/weights.h b/gemma/weights.h
index ba427ac..5fd544b 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -500,7 +500,9 @@ struct ModelWeightsPtrs {
       GEMMA_CALL_FUNC(vit_img_pos_embedding);
       GEMMA_CALL_FUNC(vit_img_head_bias);
       GEMMA_CALL_FUNC(vit_img_head_kernel);
-      GEMMA_CALL_FUNC(mm_embed_norm);
+
+      if (ptrs[0]->weights_config.wrapping == PromptWrapping::GEMMA_VLM)
+        GEMMA_CALL_FUNC(mm_embed_norm);
     }
 
     for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
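Note for reviewers: below is a minimal, standalone sketch (not part of the patch, and not the gemma.cpp API) of the per-head arithmetic that the newly added scalar DotSoftmaxWeightedSum fallback performs: scale q by 1/sqrt(qkv_dim), dot it against every k, softmax the scores, then accumulate the score-weighted v rows. The function name AttendOneHead and the flat row-major q/k/v buffers are illustrative assumptions; the real code reads q/k/v from the interleaved activations_.q batches and uses the Highway helpers MulByConst, Dot, Softmax, and MulByConstAndAdd instead of raw loops.

// Illustrative sketch only -- plain C++, no Highway. AttendOneHead and the
// flat buffers below are assumptions made for this example, not gemma.cpp
// functions or layouts.
#include <algorithm>  // std::max
#include <cmath>      // std::sqrt, std::exp
#include <cstddef>    // size_t
#include <cstdio>
#include <vector>

// q: [d]; k, v: [seq_len x d], row-major; out: [d].
void AttendOneHead(const float* q, const float* k, const float* v,
                   size_t seq_len, size_t d, float* out) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> att(seq_len);
  // Scores ("logits"): att[i] = scale * (q . k_i).
  for (size_t i = 0; i < seq_len; ++i) {
    float dot = 0.0f;
    for (size_t j = 0; j < d; ++j) dot += q[j] * k[i * d + j];
    att[i] = scale * dot;
  }
  // Softmax over the scores; subtract the max for numerical stability.
  float max_score = att[0];
  for (size_t i = 1; i < seq_len; ++i) max_score = std::max(max_score, att[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < seq_len; ++i) {
    att[i] = std::exp(att[i] - max_score);
    sum += att[i];
  }
  for (size_t i = 0; i < seq_len; ++i) att[i] /= sum;
  // Weighted sum of the v rows into out (the per-row MulByConstAndAdd step).
  for (size_t j = 0; j < d; ++j) out[j] = 0.0f;
  for (size_t i = 0; i < seq_len; ++i) {
    for (size_t j = 0; j < d; ++j) out[j] += att[i] * v[i * d + j];
  }
}

int main() {
  // Toy example: two tokens, head dimension two.
  const float q[] = {1.0f, 0.0f};
  const float k[] = {1.0f, 0.0f,   // k_0
                     0.0f, 1.0f};  // k_1
  const float v[] = {1.0f, 2.0f,   // v_0
                     3.0f, 4.0f};  // v_1
  float out[2];
  AttendOneHead(q, k, v, /*seq_len=*/2, /*d=*/2, out);
  std::printf("att_out = [%f, %f]\n", out[0], out[1]);
  return 0;
}

As the operator()() hunk shows, the GEMMA_VLM wrapping keeps the existing implementation (renamed DotSoftmaxWeightedSumMatrix), while other ViT wrappings such as PALIGEMMA go through the scalar path sketched above.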