Replaced attention in ViT with flash - 8x speedup of image tokenizer on AMD

PiperOrigin-RevId: 880877209
This commit is contained in:
Ray Smith 2026-03-09 08:45:29 -07:00 committed by Copybara-Service
parent 029cfd0b33
commit bea8b1cdbd
9 changed files with 303 additions and 148 deletions

View File

@ -555,6 +555,7 @@ cc_library(
":ops",
":tensor_stats",
":threading_context",
"@highway//:abort_header_only",
],
)
@ -678,6 +679,7 @@ cc_library(
":attention",
":basics",
":configs",
":flash_structs",
":gemma_args",
":kv_cache",
":mat",

View File

@ -76,8 +76,16 @@ struct AttentionActivations {
: batch_size * layer_config.heads,
allocator)),
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
vit_K_T(MatFactory(
"K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
layer_config.heads *
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
allocator, MatPadding::kPacked)),
vit_V_T(MatFactory(
"V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
layer_config.heads *
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
allocator, MatPadding::kPacked)),
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
config.model_dim, allocator)),
// att is only valid for AttentionImpl::kOld.
@ -126,7 +134,6 @@ struct AttentionActivations {
q.AllocateAndAttachRowPtrs(row_ptrs);
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
q_T.AllocateAndAttachRowPtrs(row_ptrs);
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
}
@ -136,8 +143,7 @@ struct AttentionActivations {
// q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
// vit_K_T and vit_V_T stay seq_len!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
@ -167,8 +173,8 @@ struct AttentionActivations {
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
MatStorageT<float> vit_Q;
MatStorageT<float> vit_K;
MatStorageT<float> vit_C;
MatStorageT<KV_t> vit_K_T;
MatStorageT<KV_t> vit_V_T;
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
q_bf = activations.q_bf;
q_T = activations.q_T;
vit_Q = activations.vit_Q;
vit_K = activations.vit_K;
vit_C = activations.vit_C;
vit_K_T = activations.vit_K_T;
vit_V_T = activations.vit_V_T;
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
@ -233,8 +239,7 @@ struct AttentionActivationsPtrs {
// q_T rows are always qkv_dim!
vit_Q.OverrideRows(batch_size);
// vit_K stays seq_len!
vit_C.OverrideRows(batch_size);
// vit_K_T and vit_V_T stay seq_len!
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
@ -267,8 +272,8 @@ struct AttentionActivationsPtrs {
MatPtrT<BF16> q_T;
MatPtrT<float> vit_Q;
MatPtrT<float> vit_K;
MatPtrT<float> vit_C;
MatPtrT<KV_t> vit_K_T;
MatPtrT<KV_t> vit_V_T;
// Output of RMSNorm before attention, size batch_size x model_dim.
MatPtrT<float> pre_att_rms_out;

View File

@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(DispatchTileFlashAttention148);
void DispatchDispatchTileFlashAttention148(
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
AttentionImpl attention_impl) {
HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)(
params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
attention_impl);
}
} // namespace gcpp
#endif

View File

@ -42,14 +42,6 @@ namespace gcpp {
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
#undef GEMMA_DECL_FLASH_ATTENTION
void DispatchDispatchTileFlashAttention148(
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
const MatPtrT<KV_t>& v, const size_t layer_idx,
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
AttentionImpl attention_impl);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_

View File

@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() {
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() {
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() {
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);

View File

@ -20,6 +20,7 @@
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/flash_structs.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@ -41,6 +42,8 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h"
#include "gemma/flash_attention.h"
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"
@ -68,107 +71,194 @@ class VitAttention {
layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
}
// TODO(philculliton): transition fully to MatMul.
HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len =
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmaxMatrix");
MatPtrT<float>& Q = activations_.attention.vit_Q;
MatPtrT<float>& K = activations_.attention.vit_K;
MatPtrT<float>& C = activations_.attention.vit_C;
// Initialize att_out to zero prior to head loop.
ZeroInit(activations_.attention.att_out);
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed
// working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
// Applies the query scale to the query and converts to QType.
template <typename QKVType, typename QType>
void ScaleQuery(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
const size_t heads, const size_t qkv_dim,
const float query_scale, MatPtrT<QType>& q_output) {
ParallelFor(Parallelism::kWithinCluster, heads, env_.ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t head, size_t worker) {
size_t q_offset = head * qkv_dim;
for (size_t token = 0; token < num_tokens; ++token) {
const float* HWY_RESTRICT src_q =
qkv.Row(token) + q_offset * 3;
QType* HWY_RESTRICT dst_q = q_output.Row(token) + q_offset;
for (size_t i = 0; i < qkv_dim; ++i) {
dst_q[i] = hwy::ConvertScalarTo<QType>(
hwy::ConvertScalarTo<float>(src_q[i]) * query_scale);
}
}
});
}
pool_.Run(
0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
CallMatMul(Q, K, nullptr, env_, C);
pool_.Run(0, num_tokens_, caller3_,
[&](uint64_t task, size_t worker)
HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); });
pool_.Run(
0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
// Transposes K and V and converts to KVType.
template <typename QKVType, typename KVType>
void TransposeKAndV(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
const size_t heads, const size_t qkv_dim,
MatPtrT<KVType>& k_output, MatPtrT<KVType>& v_output) {
using DF = hn::ScalableTag<float>;
const DF df;
const size_t kNF = hn::Lanes(df);
const size_t kNumTokensH = hwy::DivCeil(num_tokens, 2 * kNF);
const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
ParallelFor(
Parallelism::kWithinCluster, heads, env_.ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t head, size_t worker) {
const size_t qkv_offset = head * 3 * qkv_dim;
const size_t k_or_v_offset = head * 2 * kNF * kRoundedKVDim;
for (size_t token_h = 0; token_h < kNumTokensH; ++token_h) {
KVType* HWY_RESTRICT dst_k = k_output.Row(token_h);
KVType* HWY_RESTRICT dst_v = v_output.Row(token_h);
size_t dst_k_index = k_or_v_offset;
for (size_t q = 0; q < qkv_dim; q += 2) {
for (size_t token_l = 0; token_l < 2 * kNF;
++token_l, dst_k_index += 2) {
const QKVType* HWY_RESTRICT src_k =
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + qkv_dim;
dst_k[dst_k_index] = hwy::ConvertScalarTo<KVType>(src_k[q]);
dst_k[dst_k_index + 1] =
hwy::ConvertScalarTo<KVType>(src_k[q + 1]);
}
}
});
size_t dst_v_index = k_or_v_offset;
for (size_t q = 0; q < qkv_dim; q += 2 * kNF) {
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
const QKVType* HWY_RESTRICT src_v =
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset +
qkv_dim * 2;
if (q + 2 * kNF <= qkv_dim) {
for (size_t q_l = 0; q_l < 2 * kNF; ++q_l) {
dst_v[dst_v_index++] =
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
}
} else {
for (size_t q_l = 0; q_l < qkv_dim - q; ++q_l) {
dst_v[dst_v_index++] =
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
}
}
}
}
// Zero out the padding area.
// In the loops above, the dst_k loop has written 2kNF x 2
// consecutive elements for each q +=2, and the dst_v loop has
// written 2kNF x 2kNF consecutive elements for each q += 2 * kNF.
// Both of them therefore write 2kNF elements for each increment of
// q, so we can combine both into a single loop for the padding.
// This could be further simplified by writing a zero vector.
for (size_t q = qkv_dim; q < kRoundedKVDim; ++q) {
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
dst_k[dst_k_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
dst_v[dst_v_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
}
}
}
});
}
// Computes the flash attention parameters. This is mostly about deciding on
// the tile sizes and filling the param structs with the correct offsets.
template <typename QType, typename KVType>
void ComputeParams(const uint32_t num_tokens, const size_t seq_len,
const size_t heads, const uint32_t qkv_dim,
const MatPtrT<QType>& q, const MatPtrT<KVType>& k,
const MatPtrT<KVType>& v, const MatPtrT<float>& att_out,
std::vector<Tile148Params>& flash_params) {
flash_params.clear();
for (uint32_t head = 0; head < heads; ++head) {
uint32_t token = 0;
while (token + k8xNFVTileSize <= num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = k8xNFVTileSize,
.qi_index = token,
.kv_head = head,
});
token += k8xNFVTileSize;
}
if (token + k4xNFVTileSize <= num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = k4xNFVTileSize,
.qi_index = token,
.kv_head = head,
});
token += k4xNFVTileSize;
}
while (token < num_tokens) {
flash_params.push_back(Tile148Params{
.v_tile_size = 1,
.qi_index = token,
.kv_head = head,
});
token += 1;
}
}
for (auto& param : flash_params) {
param.min_start_pos = 0;
param.max_last_pos = num_tokens - 1;
for (size_t i = 0; i < param.v_tile_size; ++i) {
param.q_offsets[i] =
q.Row(param.qi_index + i) + param.kv_head * qkv_dim - q.Row(0);
param.out_offsets[i] = att_out.Row(param.qi_index + i) +
param.kv_head * qkv_dim - att_out.Row(0);
param.start_pos[i] = 0;
param.last_pos[i] = num_tokens - 1;
}
}
}
HWY_NOINLINE void DotSoftmaxWeightedSum() {
// Runs the flash attention algorithm on Q, K, V.
HWY_NOINLINE void FlashAttention() {
GCPP_ZONE(env_.ctx, 0, Zones::kVitFlashAttentionInclusive);
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len =
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
const size_t kNF = FloatsPerVector();
const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
auto& attn = activations_.attention;
const size_t seq_len = static_cast<size_t>(attn.div_seq_len.GetDivisor());
if (attn.vit_K_T.Rows() >= seq_len) {
attn.vit_K_T.ReshapePackedRowsToCols(2 * kNF);
attn.vit_V_T.ReshapePackedRowsToCols(2 * kNF);
}
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
ScaleQuery(attn.q, num_tokens_, heads, qkv_dim, query_scale, attn.q_bf);
TransposeKAndV(attn.q, num_tokens_, heads, qkv_dim, attn.vit_K_T,
attn.vit_V_T);
ComputeParams(num_tokens_, seq_len, heads, qkv_dim, attn.q_bf, attn.vit_K_T,
attn.vit_V_T, attn.att_out, attn.flash_params);
size_t num_tasks = attn.flash_params.size();
// Compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_, caller1_,
[&](uint64_t task, size_t worker) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.attention.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(Logits(head_att, seq_len), env_.ctx, worker);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.attention.att_out.Row(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
// For each param, compute fused flash Q.K, softmax and weighted V.
const auto func = [&, &ctx = env_.ctx](const size_t task,
size_t worker) HWY_ATTR {
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
auto& param = attn.flash_params[task];
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
kRoundedKVDim * 2 * kNF));
kT.SetPtr(attn.vit_K_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
attn.vit_K_T.Stride());
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
kRoundedKVDim * 2 * kNF));
vT.SetPtr(attn.vit_V_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
attn.vit_V_T.Stride());
DispatchDispatchTileFlashAttention148(
param, attn.q_bf, kT, vT, /*layer_idx=*/0, attn, attn.att_out,
qkv_dim, ctx, worker, /*attention_impl=*/AttentionImpl::kFlash);
};
{
PROFILER_ZONE("Gen.VitFlashAttention.ForkJoin");
// Full parallelism is helpful, SmallParallelFor is insufficient.
HierarchicalParallelFor(num_tasks, env_.ctx, Callers::kFlashAttention,
func);
}
}
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim (`qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_.vit.attn_out_b.PackedScale1();
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
@ -193,11 +283,7 @@ class VitAttention {
HWY_INLINE void operator()() {
ComputeQKV();
if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) {
DotSoftmaxWeightedSumMatrix();
} else {
DotSoftmaxWeightedSum();
}
FlashAttention();
SumHeads();
}

View File

@ -669,10 +669,10 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
size_t i = 0;
while (i + NF * 2 <= size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::Load(df, out + i + out_offsets[0]);
out1a = hn::Load(df, out + i + out_offsets[1]);
out2a = hn::Load(df, out + i + out_offsets[2]);
out3a = hn::Load(df, out + i + out_offsets[3]);
out0a = hn::LoadU(df, out + i + out_offsets[0]);
out1a = hn::LoadU(df, out + i + out_offsets[1]);
out2a = hn::LoadU(df, out + i + out_offsets[2]);
out3a = hn::LoadU(df, out + i + out_offsets[3]);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
@ -681,28 +681,70 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
out1b = hn::LoadU(df, out + i + NF + out_offsets[1]);
out2b = hn::LoadU(df, out + i + NF + out_offsets[2]);
out3b = hn::LoadU(df, out + i + NF + out_offsets[3]);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out1a, df, out + i + out_offsets[1]);
hn::Store(out2a, df, out + i + out_offsets[2]);
hn::Store(out3a, df, out + i + out_offsets[3]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
hn::StoreU(out0a, df, out + i + out_offsets[0]);
hn::StoreU(out1a, df, out + i + out_offsets[1]);
hn::StoreU(out2a, df, out + i + out_offsets[2]);
hn::StoreU(out3a, df, out + i + out_offsets[3]);
hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
hn::StoreU(out1b, df, out + i + NF + out_offsets[1]);
hn::StoreU(out2b, df, out + i + NF + out_offsets[2]);
hn::StoreU(out3b, df, out + i + NF + out_offsets[3]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
HWY_DASSERT(size == i);
if (i < size) {
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
out1a = hn::LoadN(df, out + i + out_offsets[1], size - i);
out2a = hn::LoadN(df, out + i + out_offsets[2], size - i);
out3a = hn::LoadN(df, out + i + out_offsets[3], size - i);
VF scale0 = hn::Set(df, scales[0]);
VF scale1 = hn::Set(df, scales[1]);
VF scale2 = hn::Set(df, scales[2]);
VF scale3 = hn::Set(df, scales[3]);
out0a = hn::Mul(out0a, scale0);
out1a = hn::Mul(out1a, scale1);
out2a = hn::Mul(out2a, scale2);
out3a = hn::Mul(out3a, scale3);
if (i + NF < size) {
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
out1b = hn::LoadN(df, out + i + NF + out_offsets[1], size - i - NF);
out2b = hn::LoadN(df, out + i + NF + out_offsets[2], size - i - NF);
out3b = hn::LoadN(df, out + i + NF + out_offsets[3], size - i - NF);
out0b = hn::Mul(out0b, scale0);
out1b = hn::Mul(out1b, scale1);
out2b = hn::Mul(out2b, scale2);
out3b = hn::Mul(out3b, scale3);
} else {
out0b = hn::Zero(df);
out1b = hn::Zero(df);
out2b = hn::Zero(df);
out3b = hn::Zero(df);
}
// Note that v_bf is always padded, so we can always load 2 * NF elements.
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
out2a, out3a, out0b, out1b, out2b, out3b);
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
hn::StoreN(out1a, df, out + i + out_offsets[1], size - i);
hn::StoreN(out2a, df, out + i + out_offsets[2], size - i);
hn::StoreN(out3a, df, out + i + out_offsets[3], size - i);
if (i + NF < size) {
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
hn::StoreN(out1b, df, out + i + NF + out_offsets[1], size - i - NF);
hn::StoreN(out2b, df, out + i + NF + out_offsets[2], size - i - NF);
hn::StoreN(out3b, df, out + i + NF + out_offsets[3], size - i - NF);
}
}
}
template <class DF, class VF = hn::Vec<DF>>
@ -743,26 +785,33 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
size_t i = 0;
while (i + NF * 2 <= size) {
VF out0a, out0b;
out0a = hn::Load(df, out + i + out_offsets[0]);
out0a = hn::LoadU(df, out + i + out_offsets[0]);
VF scale0 = hn::Set(df, scales[0]);
out0a = hn::Mul(out0a, scale0);
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
out0b = hn::Mul(out0b, scale0);
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::Store(out0a, df, out + i + out_offsets[0]);
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
hn::StoreU(out0a, df, out + i + out_offsets[0]);
hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
i += NF * 2;
v_bf += 4 * NF * NF;
}
while (i < size) {
float sum = out[i + out_offsets[0]] * scales[0];
const BF16* HWY_RESTRICT v_local = v_bf;
for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF);
++lane, v_local += 2 * NF) {
sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane];
if (i < size) {
VF out0a, out0b;
out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
VF scale0 = hn::Set(df, scales[0]);
out0a = hn::Mul(out0a, scale0);
if (i + NF < size) {
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
out0b = hn::Mul(out0b, scale0);
} else {
out0b = hn::Zero(df);
}
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
if (i + NF < size) {
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
}
++i;
++v_bf;
}
}

View File

@ -15,6 +15,8 @@ const char* ZoneName(Zones zone) {
return "FlashAttention.FlashAttention";
case Zones::kFlashAttentionInclusive:
return "FlashAttention.Inclusive";
case Zones::kVitFlashAttentionInclusive:
return "Vit.FlashAttention.Inclusive";
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
return "FlashAttention.RMSNormAndPositionalEncoding";
case Zones::kFlashAttentionTileFlashAttention1:
@ -106,6 +108,7 @@ const char* ZoneName(Zones zone) {
hwy::ProfilerFlags ZoneFlags(Zones zone) {
switch (zone) {
case Zones::kFlashAttentionInclusive:
case Zones::kVitFlashAttentionInclusive:
case Zones::kGenAttention:
case Zones::kGenAttentionComputeQKV:
case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:

View File

@ -13,6 +13,7 @@ namespace gcpp {
enum class Zones { // Keep sorted
kFlashAttentionFlashAttention,
kFlashAttentionInclusive,
kVitFlashAttentionInclusive,
kFlashAttentionRmsNormAndPositionalEncoding,
kFlashAttentionTileFlashAttention1,
kFlashAttentionTileFlashAttention4,