// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>  // sqrtf
#include <stddef.h>
#include <stdint.h>

#include <vector>

#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif  // HWY_DISABLED_TARGETS

#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/vit.cc"  // NOLINT
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h"
#include "gemma/flash_attention.h"
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

// Wrapper class; holds arguments in member variables to shorten call sites.
// The main differences from GemmaAttention are:
// - no KV cache is necessary; attention is always all-to-all and not causal.
// - no potential wrap-around; attention always goes from 0 to kSeqLen.
// - no need for batching, because we always compute attention for kSeqLen
//   tokens.
// This results in a much simpler implementation. However, to avoid
// duplicating code, we should still consider merging the two classes.
// TODO(keysers): Refactor to share code with GemmaAttention.
class VitAttention {
  // Computes Q, K, V for all heads, stored in activations_.q.
  HWY_NOINLINE void ComputeQKV() {
    PROFILER_ZONE("Gen.VitAttention.QKV");
    auto& qkv = activations_.attention.q;
    HWY_ASSERT(qkv.Rows() == num_tokens_);
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
    CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w,
               layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
  }

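  // Layout note (derived from the offsets used in ScaleQuery and
  // TransposeKAndV below): each row of `qkv` holds per-head triples
  // [q_0 k_0 v_0 | q_1 k_1 v_1 | ...], where q/k/v each have qkv_dim
  // elements. Q for head h thus starts at 3 * h * qkv_dim, K follows at
  // + qkv_dim, and V at + 2 * qkv_dim.
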
  // Applies the query scale to the query and converts to QType.
  template <typename QKVType, typename QType>
  void ScaleQuery(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
                  const size_t heads, const size_t qkv_dim,
                  const float query_scale, MatPtrT<QType>& q_output) {
    ParallelFor(Parallelism::kWithinCluster, heads, env_.ctx,
                /*cluster_idx=*/0, Callers::kFlashAttention,
                [&](size_t head, size_t worker) {
                  size_t q_offset = head * qkv_dim;
                  for (size_t token = 0; token < num_tokens; ++token) {
                    const QKVType* HWY_RESTRICT src_q =
                        qkv.Row(token) + q_offset * 3;
                    QType* HWY_RESTRICT dst_q = q_output.Row(token) + q_offset;
                    for (size_t i = 0; i < qkv_dim; ++i) {
                      dst_q[i] = hwy::ConvertScalarTo<QType>(
                          hwy::ConvertScalarTo<float>(src_q[i]) * query_scale);
                    }
                  }
                });
  }

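  // Note: pre-scaling Q once by query_scale == 1 / sqrt(qkv_dim) (set in
  // FlashAttention below) is equivalent to dividing every Q.K dot product by
  // sqrt(d_k), as in standard scaled dot-product attention; the attention
  // kernel itself can then compute plain dot products.
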
  // Transposes K and V and converts to KVType, packing them into a blocked
  // layout: per head, tokens are grouped into tiles of 2 * kNF; within a
  // tile, K is stored as (dim, dim + 1) pairs across tokens, whereas V is
  // stored as runs of 2 * kNF dims per token. Dims are zero-padded up to
  // kRoundedKVDim.
  template <typename QKVType, typename KVType>
  void TransposeKAndV(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
                      const size_t heads, const size_t qkv_dim,
                      MatPtrT<KVType>& k_output, MatPtrT<KVType>& v_output) {
    using DF = hn::ScalableTag<float>;
    const DF df;
    const size_t kNF = hn::Lanes(df);
    const size_t kNumTokensH = hwy::DivCeil(num_tokens, 2 * kNF);
    const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
    ParallelFor(
        Parallelism::kWithinCluster, heads, env_.ctx,
        /*cluster_idx=*/0, Callers::kFlashAttention,
        [&](size_t head, size_t worker) {
          const size_t qkv_offset = head * 3 * qkv_dim;
          const size_t k_or_v_offset = head * 2 * kNF * kRoundedKVDim;
          for (size_t token_h = 0; token_h < kNumTokensH; ++token_h) {
            KVType* HWY_RESTRICT dst_k = k_output.Row(token_h);
            KVType* HWY_RESTRICT dst_v = v_output.Row(token_h);
            size_t dst_k_index = k_or_v_offset;
            for (size_t q = 0; q < qkv_dim; q += 2) {
              for (size_t token_l = 0; token_l < 2 * kNF;
                   ++token_l, dst_k_index += 2) {
                const QKVType* HWY_RESTRICT src_k =
                    qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + qkv_dim;
                dst_k[dst_k_index] = hwy::ConvertScalarTo<KVType>(src_k[q]);
                dst_k[dst_k_index + 1] =
                    hwy::ConvertScalarTo<KVType>(src_k[q + 1]);
              }
            }
            size_t dst_v_index = k_or_v_offset;
            for (size_t q = 0; q < qkv_dim; q += 2 * kNF) {
              for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
                const QKVType* HWY_RESTRICT src_v =
                    qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset +
                    qkv_dim * 2;
                if (q + 2 * kNF <= qkv_dim) {
                  for (size_t q_l = 0; q_l < 2 * kNF; ++q_l) {
                    dst_v[dst_v_index++] =
                        hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
                  }
                } else {
                  // Partial final column block: copy only the remaining dims.
                  for (size_t q_l = 0; q_l < qkv_dim - q; ++q_l) {
                    dst_v[dst_v_index++] =
                        hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
                  }
                }
              }
            }
            // Zero out the padding area.
            // In the loops above, the dst_k loop has written 2 * kNF x 2
            // consecutive elements for each q += 2, and the dst_v loop has
            // written 2 * kNF x 2 * kNF consecutive elements for each
            // q += 2 * kNF. Both of them therefore write 2 * kNF elements
            // for each increment of q, so we can combine both into a single
            // loop for the padding. This could be further simplified by
            // writing a zero vector.
            for (size_t q = qkv_dim; q < kRoundedKVDim; ++q) {
              for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
                dst_k[dst_k_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
                dst_v[dst_v_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
              }
            }
          }
        });
  }

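  // Worked example of the packed K layout, assuming kNF == 4 (kNF is the
  // target-dependent number of float vector lanes), so a tile holds
  // 2 * kNF == 8 tokens t0..t7: the first 16 elements of a head's block are
  // k[t0][0], k[t0][1], k[t1][0], k[t1][1], ..., k[t7][0], k[t7][1], i.e.
  // dim pairs interleaved across the tile's tokens; the next 16 hold dims
  // 2 and 3, and so on. V instead stores 8 dims per token contiguously
  // within each 8-dim column block.
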
  // Computes the flash attention parameters. This is mostly about deciding on
  // the tile sizes and filling the param structs with the correct offsets.
  template <typename QType, typename KVType>
  void ComputeParams(const uint32_t num_tokens, const size_t seq_len,
                     const size_t heads, const uint32_t qkv_dim,
                     const MatPtrT<QType>& q, const MatPtrT<KVType>& k,
                     const MatPtrT<KVType>& v, const MatPtrT<float>& att_out,
                     std::vector<Tile148Params>& flash_params) {
    flash_params.clear();
    // Greedily cover each head's tokens with the largest tile first, then one
    // medium tile if it fits, then single-token tiles.
    for (uint32_t head = 0; head < heads; ++head) {
      uint32_t token = 0;
      while (token + k8xNFVTileSize <= num_tokens) {
        flash_params.push_back(Tile148Params{
            .v_tile_size = k8xNFVTileSize,
            .qi_index = token,
            .kv_head = head,
        });
        token += k8xNFVTileSize;
      }
      if (token + k4xNFVTileSize <= num_tokens) {
        flash_params.push_back(Tile148Params{
            .v_tile_size = k4xNFVTileSize,
            .qi_index = token,
            .kv_head = head,
        });
        token += k4xNFVTileSize;
      }
      while (token < num_tokens) {
        flash_params.push_back(Tile148Params{
            .v_tile_size = 1,
            .qi_index = token,
            .kv_head = head,
        });
        token += 1;
      }
    }
    // ViT attention is all-to-all: every query attends to all num_tokens
    // positions.
    for (auto& param : flash_params) {
      param.min_start_pos = 0;
      param.max_last_pos = num_tokens - 1;
      for (size_t i = 0; i < param.v_tile_size; ++i) {
        param.q_offsets[i] =
            q.Row(param.qi_index + i) + param.kv_head * qkv_dim - q.Row(0);
        param.out_offsets[i] = att_out.Row(param.qi_index + i) +
                               param.kv_head * qkv_dim - att_out.Row(0);
        param.start_pos[i] = 0;
        param.last_pos[i] = num_tokens - 1;
      }
    }
  }

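  // Example (with hypothetical tile sizes; the actual constants are defined
  // elsewhere): if k8xNFVTileSize were 8 and k4xNFVTileSize were 4, then
  // num_tokens == 21 would decompose per head into tiles of 8, 8, 4, 1 at
  // qi_index 0, 8, 16, 20. Larger tiles let one pass over K/V serve more
  // query rows.
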
  // Runs the flash attention algorithm on Q, K, V.
  HWY_NOINLINE void FlashAttention() {
    GCPP_ZONE(env_.ctx, 0, Zones::kVitFlashAttentionInclusive);
    const size_t qkv_dim = layer_config_.qkv_dim;
    const size_t heads = layer_config_.heads;
    HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
    const size_t kNF = FloatsPerVector();
    const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
    auto& attn = activations_.attention;
    const size_t seq_len = static_cast<size_t>(attn.div_seq_len.GetDivisor());
    if (attn.vit_K_T.Rows() >= seq_len) {
      attn.vit_K_T.ReshapePackedRowsToCols(2 * kNF);
      attn.vit_V_T.ReshapePackedRowsToCols(2 * kNF);
    }
    // Standard scaled dot-product attention: scale queries by 1 / sqrt(d_k).
    const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
    ScaleQuery(attn.q, num_tokens_, heads, qkv_dim, query_scale, attn.q_bf);
    TransposeKAndV(attn.q, num_tokens_, heads, qkv_dim, attn.vit_K_T,
                   attn.vit_V_T);
    ComputeParams(num_tokens_, seq_len, heads, qkv_dim, attn.q_bf, attn.vit_K_T,
                  attn.vit_V_T, attn.att_out, attn.flash_params);
    const size_t num_tasks = attn.flash_params.size();

    // For each param, compute fused flash Q.K, softmax and weighted V, i.e.
    // att_out[t] = softmax(q_t K^T) V for the tile's query rows.
    const auto func = [&, &ctx = env_.ctx](const size_t task,
                                           size_t worker) HWY_ATTR {
      GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
      auto& param = attn.flash_params[task];
      // kT/vT are zero-copy views of this head's slice of the packed K/V.
      MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
                                             kRoundedKVDim * 2 * kNF));
      kT.SetPtr(attn.vit_K_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
                attn.vit_K_T.Stride());
      MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
                                             kRoundedKVDim * 2 * kNF));
      vT.SetPtr(attn.vit_V_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
                attn.vit_V_T.Stride());
      DispatchDispatchTileFlashAttention148(
          param, attn.q_bf, kT, vT, /*layer_idx=*/0, attn, attn.att_out,
          qkv_dim, ctx, worker, /*attention_impl=*/AttentionImpl::kFlash);
    };

    {
      PROFILER_ZONE("Gen.VitFlashAttention.ForkJoin");
      // Full parallelism is helpful; SmallParallelFor is insufficient.
      HierarchicalParallelFor(num_tasks, env_.ctx, Callers::kFlashAttention,
                              func);
    }
  }

  // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
  // head_dim (`qkv_dim`) into output (`att_sums`).
  HWY_NOINLINE void SumHeads() {
    auto* bias = layer_.vit.attn_out_b.PackedScale1();
    // attn_out_w and att_out are concatenated heads, each of length qkv_dim.
    // Thus the [num_tokens_, layer_config_.model_dim] matmul output is the
    // sum over heads.
    CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias,
               env_, activations_.attention.att_sums);
  }

 public:
  VitAttention(size_t num_tokens, size_t layer_idx, Activations& activations,
               const LayerWeightsPtrs& layer, MatMulEnv& env)
      : num_tokens_(num_tokens),
        activations_(activations),
        layer_(layer),
        layer_config_(layer.layer_config),
        env_(env),
        pool_(env_.ctx.pools.Pool(0)),
        caller1_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax1)),
        caller2_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax2)),
        caller3_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax3)),
        caller4_(env_.ctx.pool_callers.Get(Callers::kVitDotSoftmax4)) {}

  HWY_INLINE void operator()() {
    ComputeQKV();
    FlashAttention();
    SumHeads();
  }

 private:
  const size_t num_tokens_;
  Activations& activations_;
  const LayerWeightsPtrs& layer_;
  const LayerConfig& layer_config_;
  MatMulEnv& env_;
  hwy::ThreadPool& pool_;
  hwy::pool::Caller caller1_;
  hwy::pool::Caller caller2_;
  hwy::pool::Caller caller3_;
  hwy::pool::Caller caller4_;
};

// Same as FFWNoVit, but with different layer members and no second
// gating matrix.
void FFWVit(const LayerWeightsPtrs& layer, Activations& activations,
            MatMulEnv& env) {
  PROFILER_ZONE("Gen.FFW.ViT");
  const LayerConfig& layer_config = layer.layer_config;

  const bool add_bias = layer_config.ff_biases;
  const float* bias1 = add_bias ? layer.vit.linear_0_b.PackedScale1() : nullptr;
  const float* output_bias =
      add_bias ? layer.vit.linear_1_b.PackedScale1() : nullptr;

  // Compute the hidden layer activations.
  CallMatMul(activations.pre_ffw_rms_out, layer.vit.linear_0_w, bias1, env,
             activations.C1);

  // Activation (Gelu), store in C1.
  ActivationBatched(layer_config.activation, activations.C1, env.ctx);

  // Hidden layer -> output layer.
  CallMatMul(activations.C1, layer.vit.linear_1_w, output_bias, env,
             activations.ffw_out);
}

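// In formula form, FFWVit computes the standard (non-gated) MLP block:
//   ffw_out = Gelu(pre_ffw_rms_out * W_0 + b_0) * W_1 + b_1,
// with W_0/b_0 = linear_0_w/linear_0_b and W_1/b_1 = linear_1_w/linear_1_b
// (biases only when layer_config.ff_biases is set).
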
// Vit transformer layer. Some comments below refer to the Vit implementation
// in the Big Vision codebase. See
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
// TODO(keysers): consider adding a wrapper for both LayerNorm and RMSNorm and
// try merging this with TransformerLayer.
void VitTransformerLayer(size_t num_tokens, const size_t layer_idx,
                         const LayerWeightsPtrs& layer,
                         Activations& activations, MatMulEnv& env) {
  const size_t model_dim = activations.attention.config.model_dim;
  auto type = layer.layer_config.type;
  HWY_DASSERT(type == LayerAttentionType::kVit);
  (void)type;
  (void)model_dim;

  auto& x = activations.x;
  HWY_DASSERT(x.Rows() == num_tokens);
  HWY_DASSERT(x.Cols() == model_dim);

  // y = nn.LayerNorm()(x)
  // y ~ pre_att_rms_out
  LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias,
                   activations.attention.pre_att_rms_out);

  // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
  // y ~ att_sums
  VitAttention(num_tokens, layer_idx, activations, layer, env)();

  // x = out["+sa"] = x + y
  AddFromBatched(activations.attention.att_sums, x, env.ctx);

  // y = nn.LayerNorm()(x)
  // y ~ pre_ffw_rms_out
  LayerNormBatched(x, layer.vit.layer_norm_1_scale, layer.vit.layer_norm_1_bias,
                   activations.pre_ffw_rms_out);

  // y = out["mlp"] = MlpBlock(...)(y)
  // y ~ ffw_out
  FFWVit(layer, activations, env);

  // x = out["+mlp"] = x + y
  AddFromBatched(activations.ffw_out, x, env.ctx);
}

// Gets the patches of the image and embeds them with the image embedding
// kernel. The result is stored in activations.x.
static HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                           const ModelConfig& model_config,
                                           const WeightsPtrs& weights,
                                           Activations& activations,
                                           MatMulEnv& env) {
  const size_t model_dim = model_config.vit_config.model_dim;
  const size_t patch_width = model_config.vit_config.patch_width;
  const size_t num_tokens = model_config.vit_config.seq_len;
  const size_t patch_area = patch_width * patch_width * 3;
  const hwy::Divisor div_patch_dim(patch_width);
  HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
  HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_area);
  HWY_DASSERT(activations.x.Cols() == model_dim);
  (void)model_dim;
  // img/embedding/kernel has original shape (14, 14, 3, 1152), i.e.
  // H x W x C x D, transposed to D x (H x W x C), so here (1152, 14 * 14 * 3).
  // image_patches is (256, 14 * 14 * 3).
  // Must be padded, see `DoDecompressA`.
  MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_area),
                                   env.ctx.allocator, MatPadding::kOdd);
  for (size_t i = 0; i < num_tokens; ++i) {
    image.GetPatch(i, div_patch_dim, image_patches.Row(i));
  }
  CallMatMul(image_patches, weights.vit_img_embedding_kernel,
             weights.vit_img_embedding_bias.PackedScale1(), env, activations.x);
  // Add position embeddings.
  CallUpcastedActivation(&weights.vit_img_pos_embedding,
                         [&](const auto* weights_t) {
                           AddFromBatched(*weights_t, activations.x, env.ctx);
                         });
}

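// Shape check for the comment above: 256 tokens with patch_width 14 implies a
// 224 x 224 input, since (224 / 14)^2 = 16 * 16 = 256 patches; each patch
// flattens to 14 * 14 * 3 = 588 values, matching image_patches (256, 588) and
// the transposed embedding kernel (1152, 588).
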
// Prefills the image tokens with the ViT encoder.
void PrefillVit(const ModelConfig& model_config, const WeightsPtrs& weights,
                const RuntimeConfig& runtime_config, const Image& image,
                ImageTokens& image_tokens, Activations& activations,
                MatMulEnv& env) {
  PROFILER_ZONE("Gen.PrefillVit");
  const size_t num_tokens = model_config.vit_config.seq_len;
  const size_t vit_model_dim = model_config.vit_config.model_dim;
  HWY_ASSERT(num_tokens == activations.x.Rows());
  // Embed the image patches.
  EmbedImagePatches(image, model_config, weights, activations, env);
  // Go through all layers.
  for (size_t layer_idx = 0;
       layer_idx < model_config.vit_config.layer_configs.size(); ++layer_idx) {
    VitTransformerLayer(num_tokens, layer_idx, *weights.VitLayer(layer_idx),
                        activations, env);
  }
  // Final LayerNorm.
  LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
                   weights.vit_encoder_norm_bias, activations.x);

  if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
    activations.x = AvgPool4x4(activations.x, env.ctx.allocator);

    // Apply soft embedding norm before input projection.
    CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
      RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0,
                     activations.x.Row(0), vit_model_dim, env.ctx,
                     hwy::Profiler::GlobalIdx());
    });
  }

  // Apply the head projection into image_tokens, whose width is the LLM's
  // model_dim.
  CallMatMul(activations.x, weights.vit_img_head_kernel,
             weights.vit_img_head_bias.PackedScale1(), env, image_tokens);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();