Replaced attention in ViT with flash attention - 8x speedup of image tokenizer on AMD

PiperOrigin-RevId: 880877209
This commit is contained in:
Ray Smith 2026-03-09 08:45:29 -07:00 committed by Copybara-Service
parent 029cfd0b33
commit bea8b1cdbd
9 changed files with 303 additions and 148 deletions

View File

@ -555,6 +555,7 @@ cc_library(
":ops", ":ops",
":tensor_stats", ":tensor_stats",
":threading_context", ":threading_context",
"@highway//:abort_header_only",
], ],
) )
@ -678,6 +679,7 @@ cc_library(
":attention", ":attention",
":basics", ":basics",
":configs", ":configs",
":flash_structs",
":gemma_args", ":gemma_args",
":kv_cache", ":kv_cache",
":mat", ":mat",

View File

@ -76,8 +76,16 @@ struct AttentionActivations {
: batch_size * layer_config.heads, : batch_size * layer_config.heads,
allocator)), allocator)),
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)), vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)), vit_K_T(MatFactory(
vit_C(MatFactory("C2", batch_size, seq_len, allocator)), "K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
layer_config.heads *
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
allocator, MatPadding::kPacked)),
vit_V_T(MatFactory(
"V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
layer_config.heads *
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
allocator, MatPadding::kPacked)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size, pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)), config.model_dim, allocator)),
// att is only valid for AttentionImpl::kOld. // att is only valid for AttentionImpl::kOld.
@ -126,7 +134,6 @@ struct AttentionActivations {
q.AllocateAndAttachRowPtrs(row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs); q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs); q_T.AllocateAndAttachRowPtrs(row_ptrs);
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs);
} }
@ -136,8 +143,7 @@ struct AttentionActivations {
// q_T rows are always qkv_dim! // q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size); vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len! // vit_K_T and vit_V_T stay seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
@ -167,8 +173,8 @@ struct AttentionActivations {
MatStorageT<BF16> q_T; // Transposed to maximize attention speed. MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> vit_Q; MatStorageT<float> vit_Q;
MatStorageT<float> vit_K; MatStorageT<KV_t> vit_K_T;
MatStorageT<float> vit_C; MatStorageT<KV_t> vit_V_T;
MatStorageT<float> pre_att_rms_out; MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector MatStorageT<float> att; // attention vector
@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
q_bf = activations.q_bf; q_bf = activations.q_bf;
q_T = activations.q_T; q_T = activations.q_T;
vit_Q = activations.vit_Q; vit_Q = activations.vit_Q;
vit_K = activations.vit_K; vit_K_T = activations.vit_K_T;
vit_C = activations.vit_C; vit_V_T = activations.vit_V_T;
pre_att_rms_out = activations.pre_att_rms_out; pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att; att = activations.att;
att_out = activations.att_out; att_out = activations.att_out;
@ -233,8 +239,7 @@ struct AttentionActivationsPtrs {
// q_T rows are always qkv_dim! // q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size); vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len! // vit_K_T and vit_V_T stay seq_len!
vit_C.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size); pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size); att.OverrideRows(batch_size);
@ -267,8 +272,8 @@ struct AttentionActivationsPtrs {
MatPtrT<BF16> q_T; MatPtrT<BF16> q_T;
MatPtrT<float> vit_Q; MatPtrT<float> vit_Q;
MatPtrT<float> vit_K; MatPtrT<KV_t> vit_K_T;
MatPtrT<float> vit_C; MatPtrT<KV_t> vit_V_T;
// Output of RMSNorm before attention, size batch_size x model_dim. // Output of RMSNorm before attention, size batch_size x model_dim.
MatPtrT<float> pre_att_rms_out; MatPtrT<float> pre_att_rms_out;

View File

@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp
HWY_AFTER_NAMESPACE(); HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(DispatchTileFlashAttention148);
// Target-independent entry point for the per-target
// DispatchTileFlashAttention148. Forwards every argument unchanged through
// HWY_DYNAMIC_DISPATCH, relying on the HWY_EXPORT of the same name above.
// Compiled only once (inside the HWY_ONCE section), not once per SIMD target.
void DispatchDispatchTileFlashAttention148(
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
AttentionImpl attention_impl) {
HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)(
params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
attention_impl);
}
} // namespace gcpp
#endif

View File

@ -42,14 +42,6 @@ namespace gcpp {
const MatPtr& query_norm_scale, size_t layer_idx, \ const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\ \
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \ size_t total_tasks, size_t target_parallelism); \
\ \
@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
#undef GEMMA_DECL_FLASH_ATTENTION #undef GEMMA_DECL_FLASH_ATTENTION
// Wrapper that calls the per-target DispatchTileFlashAttention148 with the
// same argument list. Declared after the HWY_VISIT_TARGETS expansion, so there
// is a single declaration in namespace gcpp rather than one per target.
void DispatchDispatchTileFlashAttention148(
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
AttentionImpl attention_impl);
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_

View File

@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() {
test_env.SetupWeights(); test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q); FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max); FillMatPtrT(test_env.activations->attention.softmax_max);
@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
test_env.SetupWeights(); test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q); FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max); FillMatPtrT(test_env.activations->attention.softmax_max);
@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() {
test_env.SetupWeights(); test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q); FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max); FillMatPtrT(test_env.activations->attention.softmax_max);
@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() {
test_env.SetupWeights(); test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out); FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q); FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att); FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out); FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max); FillMatPtrT(test_env.activations->attention.softmax_max);

View File

@ -20,6 +20,7 @@
#include <vector> #include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS
@ -41,6 +42,8 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h // After highway.h
#include "gemma/attention.h"
#include "gemma/flash_attention.h"
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
@ -68,107 +71,194 @@ class VitAttention {
layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv); layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
} }
// TODO(philculliton): transition fully to MatMul. // Applies the query scale to the query and converts to QType.
HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() { template <typename QKVType, typename QType>
const size_t qkv_dim = layer_config_.qkv_dim; void ScaleQuery(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
const size_t heads = layer_config_.heads; const size_t heads, const size_t qkv_dim,
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); const float query_scale, MatPtrT<QType>& q_output) {
const size_t seq_len = ParallelFor(Parallelism::kWithinCluster, heads, env_.ctx,
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor()); /*cluster_idx=*/0, Callers::kFlashAttention,
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim)); [&](size_t head, size_t worker) {
PROFILER_ZONE("Gen.VitAttention.DotSoftmaxMatrix"); size_t q_offset = head * qkv_dim;
for (size_t token = 0; token < num_tokens; ++token) {
MatPtrT<float>& Q = activations_.attention.vit_Q; const float* HWY_RESTRICT src_q =
MatPtrT<float>& K = activations_.attention.vit_K; qkv.Row(token) + q_offset * 3;
MatPtrT<float>& C = activations_.attention.vit_C; QType* HWY_RESTRICT dst_q = q_output.Row(token) + q_offset;
for (size_t i = 0; i < qkv_dim; ++i) {
// Initialize att_out to zero prior to head loop. dst_q[i] = hwy::ConvertScalarTo<QType>(
ZeroInit(activations_.attention.att_out); hwy::ConvertScalarTo<float>(src_q[i]) * query_scale);
}
for (size_t head = 0; head < heads; ++head) { }
pool_.Run(0, num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed
// working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
}); });
}
pool_.Run( // Transposes K and V and converts to KVType.
0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { template <typename QKVType, typename KVType>
const size_t seq_idx = task; void TransposeKAndV(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + const size_t heads, const size_t qkv_dim,
head * 3 * qkv_dim + qkv_dim; MatPtrT<KVType>& k_output, MatPtrT<KVType>& v_output) {
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); using DF = hn::ScalableTag<float>;
}); const DF df;
const size_t kNF = hn::Lanes(df);
// this produces C, a (num_tokens_, seq_len) matrix of dot products const size_t kNumTokensH = hwy::DivCeil(num_tokens, 2 * kNF);
CallMatMul(Q, K, nullptr, env_, C); const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
ParallelFor(
pool_.Run(0, num_tokens_, caller3_, Parallelism::kWithinCluster, heads, env_.ctx,
[&](uint64_t task, size_t worker) /*cluster_idx=*/0, Callers::kFlashAttention,
HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); }); [&](size_t head, size_t worker) {
const size_t qkv_offset = head * 3 * qkv_dim;
pool_.Run( const size_t k_or_v_offset = head * 2 * kNF * kRoundedKVDim;
0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR { for (size_t token_h = 0; token_h < kNumTokensH; ++token_h) {
size_t token = task; KVType* HWY_RESTRICT dst_k = k_output.Row(token_h);
float* HWY_RESTRICT att_out = KVType* HWY_RESTRICT dst_v = v_output.Row(token_h);
activations_.attention.att_out.Row(token) + head * qkv_dim; size_t dst_k_index = k_or_v_offset;
for (size_t i = 0; i < seq_len; ++i) { for (size_t q = 0; q < qkv_dim; q += 2) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + for (size_t token_l = 0; token_l < 2 * kNF;
head * 3 * qkv_dim + 2 * qkv_dim; ++token_l, dst_k_index += 2) {
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); const QKVType* HWY_RESTRICT src_k =
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + qkv_dim;
dst_k[dst_k_index] = hwy::ConvertScalarTo<KVType>(src_k[q]);
dst_k[dst_k_index + 1] =
hwy::ConvertScalarTo<KVType>(src_k[q + 1]);
}
} }
}); size_t dst_v_index = k_or_v_offset;
for (size_t q = 0; q < qkv_dim; q += 2 * kNF) {
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
const QKVType* HWY_RESTRICT src_v =
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset +
qkv_dim * 2;
if (q + 2 * kNF <= qkv_dim) {
for (size_t q_l = 0; q_l < 2 * kNF; ++q_l) {
dst_v[dst_v_index++] =
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
}
} else {
for (size_t q_l = 0; q_l < qkv_dim - q; ++q_l) {
dst_v[dst_v_index++] =
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
}
}
}
}
// Zero out the padding area.
// In the loops above, the dst_k loop has written 2kNF x 2
// consecutive elements for each q +=2, and the dst_v loop has
// written 2kNF x 2kNF consecutive elements for each q += 2 * kNF.
// Both of them therefore write 2kNF elements for each increment of
// q, so we can combine both into a single loop for the padding.
// This could be further simplified by writing a zero vector.
for (size_t q = qkv_dim; q < kRoundedKVDim; ++q) {
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
dst_k[dst_k_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
dst_v[dst_v_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
}
}
}
});
}
// Computes the flash attention parameters. This is mostly about deciding on
// the tile sizes and filling the param structs with the correct offsets.
//
// `flash_params` is cleared and refilled with one Tile148Params per tile of
// query tokens, for every head in [0, heads). Offsets are derived from `q` and
// `att_out` only; `seq_len`, `k` and `v` are not referenced in this body.
template <typename QType, typename KVType>
void ComputeParams(const uint32_t num_tokens, const size_t seq_len,
const size_t heads, const uint32_t qkv_dim,
const MatPtrT<QType>& q, const MatPtrT<KVType>& k,
const MatPtrT<KVType>& v, const MatPtrT<float>& att_out,
std::vector<Tile148Params>& flash_params) {
flash_params.clear();
// Per head, cover tokens [0, num_tokens) greedily: as many k8xNFVTileSize
// tiles as fit, then at most one k4xNFVTileSize tile, then single-token tiles
// for whatever remains.
for (uint32_t head = 0; head < heads; ++head) {
uint32_t token = 0;
while (token + k8xNFVTileSize <= num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = k8xNFVTileSize,
.qi_index = token,
.kv_head = head,
});
token += k8xNFVTileSize;
}
if (token + k4xNFVTileSize <= num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = k4xNFVTileSize,
.qi_index = token,
.kv_head = head,
});
token += k4xNFVTileSize;
}
while (token < num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = 1,
.qi_index = token,
.kv_head = head,
});
token += 1;
}
}
// Fill in the remaining fields of every tile. No tiles are pushed when
// num_tokens == 0, so `num_tokens - 1` below is only evaluated for
// num_tokens >= 1.
for (auto& param : flash_params) {
param.min_start_pos = 0;
param.max_last_pos = num_tokens - 1;
for (size_t i = 0; i < param.v_tile_size; ++i) {
// Element offsets, relative to Row(0), of this head's qkv_dim-wide slice in
// row qi_index + i of `q` and of `att_out` respectively.
param.q_offsets[i] =
q.Row(param.qi_index + i) + param.kv_head * qkv_dim - q.Row(0);
param.out_offsets[i] = att_out.Row(param.qi_index + i) +
param.kv_head * qkv_dim - att_out.Row(0);
// Every query row in the tile gets the full position range
// [0, num_tokens - 1].
param.start_pos[i] = 0;
param.last_pos[i] = num_tokens - 1;
}
} }
} }
HWY_NOINLINE void DotSoftmaxWeightedSum() { // Runs the flash attention algorithm on Q, K, V.
HWY_NOINLINE void FlashAttention() {
GCPP_ZONE(env_.ctx, 0, Zones::kVitFlashAttentionInclusive);
const size_t qkv_dim = layer_config_.qkv_dim; const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads; const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = const size_t kNF = FloatsPerVector();
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor()); const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
auto& attn = activations_.attention;
const size_t seq_len = static_cast<size_t>(attn.div_seq_len.GetDivisor());
if (attn.vit_K_T.Rows() >= seq_len) {
attn.vit_K_T.ReshapePackedRowsToCols(2 * kNF);
attn.vit_V_T.ReshapePackedRowsToCols(2 * kNF);
}
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim)); const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); ScaleQuery(attn.q, num_tokens_, heads, qkv_dim, query_scale, attn.q_bf);
TransposeKAndV(attn.q, num_tokens_, heads, qkv_dim, attn.vit_K_T,
attn.vit_V_T);
ComputeParams(num_tokens_, seq_len, heads, qkv_dim, attn.q_bf, attn.vit_K_T,
attn.vit_V_T, attn.att_out, attn.flash_params);
size_t num_tasks = attn.flash_params.size();
// Compute Q.K, softmax, and weighted V. // For each param, compute fused flash Q.K, softmax and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_, caller1_, const auto func = [&, &ctx = env_.ctx](const size_t task,
[&](uint64_t task, size_t worker) HWY_ATTR { size_t worker) HWY_ATTR {
const size_t head = task % layer_config_.heads; GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
const size_t token = task / layer_config_.heads; auto& param = attn.flash_params[task];
// Compute Q.K scores, which are "logits" stored in head_att. MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
float* HWY_RESTRICT q = kRoundedKVDim * 2 * kNF));
activations_.attention.q.Row(token) + head * 3 * qkv_dim; kT.SetPtr(attn.vit_K_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
MulByConst(query_scale, q, qkv_dim); attn.vit_K_T.Stride());
float* HWY_RESTRICT head_att = MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
activations_.attention.att.Row(token) + head * seq_len; kRoundedKVDim * 2 * kNF));
for (size_t i = 0; i < seq_len; ++i) { vT.SetPtr(attn.vit_V_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
float* HWY_RESTRICT k = activations_.attention.q.Row(i) + attn.vit_V_T.Stride());
head * 3 * qkv_dim + qkv_dim; DispatchDispatchTileFlashAttention148(
head_att[i] = Dot(q, k, qkv_dim); // score = q.k param, attn.q_bf, kT, vT, /*layer_idx=*/0, attn, attn.att_out,
} qkv_dim, ctx, worker, /*attention_impl=*/AttentionImpl::kFlash);
// SoftMax yields "probabilities" in head_att. };
Softmax(Logits(head_att, seq_len), env_.ctx, worker);
// Compute weighted sum of v into att_out. {
float* HWY_RESTRICT att_out = PROFILER_ZONE("Gen.VitFlashAttention.ForkJoin");
activations_.attention.att_out.Row(token) + head * qkv_dim; // Full parallelism is helpful, SmallParallelFor is insufficient.
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); HierarchicalParallelFor(num_tasks, env_.ctx, Callers::kFlashAttention,
for (size_t i = 0; i < seq_len; ++i) { func);
float* HWY_RESTRICT v = activations_.attention.q.Row(i) + }
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
} }
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim (`qkv_dim`) into output (`att_sums`). // head_dim (`qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() { HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_.vit.attn_out_b.PackedScale1(); auto* bias = layer_.vit.attn_out_b.PackedScale1();
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
@ -193,11 +283,7 @@ class VitAttention {
HWY_INLINE void operator()() { HWY_INLINE void operator()() {
ComputeQKV(); ComputeQKV();
if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) { FlashAttention();
DotSoftmaxWeightedSumMatrix();
} else {
DotSoftmaxWeightedSum();
}
SumHeads(); SumHeads();
} }

View File

@ -669,10 +669,10 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
size_t i = 0; size_t i = 0;
while (i + NF * 2 <= size) { while (i + NF * 2 <= size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b; VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::Load(df, out + i + out_offsets[0]); out0a = hn::LoadU(df, out + i + out_offsets[0]);
out1a = hn::Load(df, out + i + out_offsets[1]); out1a = hn::LoadU(df, out + i + out_offsets[1]);
out2a = hn::Load(df, out + i + out_offsets[2]); out2a = hn::LoadU(df, out + i + out_offsets[2]);
out3a = hn::Load(df, out + i + out_offsets[3]); out3a = hn::LoadU(df, out + i + out_offsets[3]);
VF scale0 = hn::Set(df, scales[0]); VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]); VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]); VF scale2 = hn::Set(df, scales[2]);
@ -681,28 +681,70 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
out1a = hn::Mul(out1a, scale1); out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2); out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3); out3a = hn::Mul(out3a, scale3);
out0b = hn::Load(df, out + i + NF + out_offsets[0]); out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
out1b = hn::Load(df, out + i + NF + out_offsets[1]); out1b = hn::LoadU(df, out + i + NF + out_offsets[1]);
out2b = hn::Load(df, out + i + NF + out_offsets[2]); out2b = hn::LoadU(df, out + i + NF + out_offsets[2]);
out3b = hn::Load(df, out + i + NF + out_offsets[3]); out3b = hn::LoadU(df, out + i + NF + out_offsets[3]);
out0b = hn::Mul(out0b, scale0); out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1); out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2); out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3); out3b = hn::Mul(out3b, scale3);
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a, MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b); out2a, out3a, out0b, out1b, out2b, out3b);
hn::Store(out0a, df, out + i + out_offsets[0]); hn::StoreU(out0a, df, out + i + out_offsets[0]);
hn::Store(out1a, df, out + i + out_offsets[1]); hn::StoreU(out1a, df, out + i + out_offsets[1]);
hn::Store(out2a, df, out + i + out_offsets[2]); hn::StoreU(out2a, df, out + i + out_offsets[2]);
hn::Store(out3a, df, out + i + out_offsets[3]); hn::StoreU(out3a, df, out + i + out_offsets[3]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]); hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
hn::Store(out1b, df, out + i + NF + out_offsets[1]); hn::StoreU(out1b, df, out + i + NF + out_offsets[1]);
hn::Store(out2b, df, out + i + NF + out_offsets[2]); hn::StoreU(out2b, df, out + i + NF + out_offsets[2]);
hn::Store(out3b, df, out + i + NF + out_offsets[3]); hn::StoreU(out3b, df, out + i + NF + out_offsets[3]);
i += NF * 2; i += NF * 2;
v_bf += 4 * NF * NF; v_bf += 4 * NF * NF;
} }
HWY_DASSERT(size == i); if (i < size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
out1a = hn::LoadN(df, out + i + out_offsets[1], size - i);
out2a = hn::LoadN(df, out + i + out_offsets[2], size - i);
out3a = hn::LoadN(df, out + i + out_offsets[3], size - i);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
VF scale3 = hn::Set(df, scales[3]);
out0a = hn::Mul(out0a, scale0);
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
if (i + NF < size) {
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
out1b = hn::LoadN(df, out + i + NF + out_offsets[1], size - i - NF);
out2b = hn::LoadN(df, out + i + NF + out_offsets[2], size - i - NF);
out3b = hn::LoadN(df, out + i + NF + out_offsets[3], size - i - NF);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
} else {
out0b = hn::Zero(df);
out1b = hn::Zero(df);
out2b = hn::Zero(df);
out3b = hn::Zero(df);
}
// Note that v_bf is always padded, so we can always load 2 * NF elements.
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
hn::StoreN(out1a, df, out + i + out_offsets[1], size - i);
hn::StoreN(out2a, df, out + i + out_offsets[2], size - i);
hn::StoreN(out3a, df, out + i + out_offsets[3], size - i);
if (i + NF < size) {
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
hn::StoreN(out1b, df, out + i + NF + out_offsets[1], size - i - NF);
hn::StoreN(out2b, df, out + i + NF + out_offsets[2], size - i - NF);
hn::StoreN(out3b, df, out + i + NF + out_offsets[3], size - i - NF);
}
}
} }
template <class DF, class VF = hn::Vec<DF>> template <class DF, class VF = hn::Vec<DF>>
@ -743,26 +785,33 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
size_t i = 0; size_t i = 0;
while (i + NF * 2 <= size) { while (i + NF * 2 <= size) {
VF out0a, out0b; VF out0a, out0b;
out0a = hn::Load(df, out + i + out_offsets[0]); out0a = hn::LoadU(df, out + i + out_offsets[0]);
VF scale0 = hn::Set(df, scales[0]); VF scale0 = hn::Set(df, scales[0]);
out0a = hn::Mul(out0a, scale0); out0a = hn::Mul(out0a, scale0);
out0b = hn::Load(df, out + i + NF + out_offsets[0]); out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
out0b = hn::Mul(out0b, scale0); out0b = hn::Mul(out0b, scale0);
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b); MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::Store(out0a, df, out + i + out_offsets[0]); hn::StoreU(out0a, df, out + i + out_offsets[0]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]); hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
i += NF * 2; i += NF * 2;
v_bf += 4 * NF * NF; v_bf += 4 * NF * NF;
} }
while (i < size) { if (i < size) {
float sum = out[i + out_offsets[0]] * scales[0]; VF out0a, out0b;
const BF16* HWY_RESTRICT v_local = v_bf; out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF); VF scale0 = hn::Set(df, scales[0]);
++lane, v_local += 2 * NF) { out0a = hn::Mul(out0a, scale0);
sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane]; if (i + NF < size) {
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
out0b = hn::Mul(out0b, scale0);
} else {
out0b = hn::Zero(df);
}
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
if (i + NF < size) {
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
} }
++i;
++v_bf;
} }
} }

View File

@ -15,6 +15,8 @@ const char* ZoneName(Zones zone) {
return "FlashAttention.FlashAttention"; return "FlashAttention.FlashAttention";
case Zones::kFlashAttentionInclusive: case Zones::kFlashAttentionInclusive:
return "FlashAttention.Inclusive"; return "FlashAttention.Inclusive";
case Zones::kVitFlashAttentionInclusive:
return "Vit.FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding: case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding"; return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionTileFlashAttention1: case Zones::kFlashAttentionTileFlashAttention1:
@ -106,6 +108,7 @@ const char* ZoneName(Zones zone) {
hwy::ProfilerFlags ZoneFlags(Zones zone) { hwy::ProfilerFlags ZoneFlags(Zones zone) {
switch (zone) { switch (zone) {
case Zones::kFlashAttentionInclusive: case Zones::kFlashAttentionInclusive:
case Zones::kVitFlashAttentionInclusive:
case Zones::kGenAttention: case Zones::kGenAttention:
case Zones::kGenAttentionComputeQKV: case Zones::kGenAttentionComputeQKV:
case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive: case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:

View File

@ -13,6 +13,7 @@ namespace gcpp {
enum class Zones { // Keep sorted enum class Zones { // Keep sorted
kFlashAttentionFlashAttention, kFlashAttentionFlashAttention,
kFlashAttentionInclusive, kFlashAttentionInclusive,
kVitFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding, kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionTileFlashAttention1, kFlashAttentionTileFlashAttention1,
kFlashAttentionTileFlashAttention4, kFlashAttentionTileFlashAttention4,