mirror of https://github.com/google/gemma.cpp.git
Replaced attention in ViT with flash - 8x speedup of image tokenizer on AMD
PiperOrigin-RevId: 880877209
This commit is contained in:
parent
029cfd0b33
commit
bea8b1cdbd
|
|
@ -555,6 +555,7 @@ cc_library(
|
|||
":ops",
|
||||
":tensor_stats",
|
||||
":threading_context",
|
||||
"@highway//:abort_header_only",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -678,6 +679,7 @@ cc_library(
|
|||
":attention",
|
||||
":basics",
|
||||
":configs",
|
||||
":flash_structs",
|
||||
":gemma_args",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
|
|
|
|||
|
|
@ -76,8 +76,16 @@ struct AttentionActivations {
|
|||
: batch_size * layer_config.heads,
|
||||
allocator)),
|
||||
vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
|
||||
vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
|
||||
vit_C(MatFactory("C2", batch_size, seq_len, allocator)),
|
||||
vit_K_T(MatFactory(
|
||||
"K2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
|
||||
layer_config.heads *
|
||||
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
|
||||
allocator, MatPadding::kPacked)),
|
||||
vit_V_T(MatFactory(
|
||||
"V2_T", hwy::RoundUpTo(seq_len, kMaxBF16PerVector),
|
||||
layer_config.heads *
|
||||
hwy::RoundUpTo(layer_config.qkv_dim, kMaxBF16PerVector),
|
||||
allocator, MatPadding::kPacked)),
|
||||
pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
|
||||
config.model_dim, allocator)),
|
||||
// att is only valid for AttentionImpl::kOld.
|
||||
|
|
@ -126,7 +134,6 @@ struct AttentionActivations {
|
|||
q.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
q_bf.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
q_T.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
vit_C.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
att_sums.AllocateAndAttachRowPtrs(row_ptrs);
|
||||
}
|
||||
|
||||
|
|
@ -136,8 +143,7 @@ struct AttentionActivations {
|
|||
// q_T rows are always qkv_dim!
|
||||
|
||||
vit_Q.OverrideRows(batch_size);
|
||||
// vit_K stays seq_len!
|
||||
vit_C.OverrideRows(batch_size);
|
||||
// vit_K_T and vit_V_T stay seq_len!
|
||||
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
|
|
@ -167,8 +173,8 @@ struct AttentionActivations {
|
|||
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
|
||||
|
||||
MatStorageT<float> vit_Q;
|
||||
MatStorageT<float> vit_K;
|
||||
MatStorageT<float> vit_C;
|
||||
MatStorageT<KV_t> vit_K_T;
|
||||
MatStorageT<KV_t> vit_V_T;
|
||||
|
||||
MatStorageT<float> pre_att_rms_out;
|
||||
MatStorageT<float> att; // attention vector
|
||||
|
|
@ -214,8 +220,8 @@ struct AttentionActivationsPtrs {
|
|||
q_bf = activations.q_bf;
|
||||
q_T = activations.q_T;
|
||||
vit_Q = activations.vit_Q;
|
||||
vit_K = activations.vit_K;
|
||||
vit_C = activations.vit_C;
|
||||
vit_K_T = activations.vit_K_T;
|
||||
vit_V_T = activations.vit_V_T;
|
||||
pre_att_rms_out = activations.pre_att_rms_out;
|
||||
att = activations.att;
|
||||
att_out = activations.att_out;
|
||||
|
|
@ -233,8 +239,7 @@ struct AttentionActivationsPtrs {
|
|||
// q_T rows are always qkv_dim!
|
||||
|
||||
vit_Q.OverrideRows(batch_size);
|
||||
// vit_K stays seq_len!
|
||||
vit_C.OverrideRows(batch_size);
|
||||
// vit_K_T and vit_V_T stay seq_len!
|
||||
|
||||
pre_att_rms_out.OverrideRows(batch_size);
|
||||
att.OverrideRows(batch_size);
|
||||
|
|
@ -267,8 +272,8 @@ struct AttentionActivationsPtrs {
|
|||
MatPtrT<BF16> q_T;
|
||||
|
||||
MatPtrT<float> vit_Q;
|
||||
MatPtrT<float> vit_K;
|
||||
MatPtrT<float> vit_C;
|
||||
MatPtrT<KV_t> vit_K_T;
|
||||
MatPtrT<KV_t> vit_V_T;
|
||||
|
||||
// Output of RMSNorm before attention, size batch_size x model_dim.
|
||||
MatPtrT<float> pre_att_rms_out;
|
||||
|
|
|
|||
|
|
@ -2260,3 +2260,21 @@ void FlashAttention(const size_t num_tokens, const size_t target_parallelism,
|
|||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
HWY_EXPORT(DispatchTileFlashAttention148);
|
||||
|
||||
void DispatchDispatchTileFlashAttention148(
|
||||
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
|
||||
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
|
||||
AttentionImpl attention_impl) {
|
||||
HWY_DYNAMIC_DISPATCH(DispatchTileFlashAttention148)(
|
||||
params, q, k, v, layer_idx, activations, att_out, qkv_dim, ctx, worker,
|
||||
attention_impl);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -42,14 +42,6 @@ namespace gcpp {
|
|||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const BF16* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
|
|
@ -83,6 +75,13 @@ HWY_VISIT_TARGETS(GEMMA_DECL_FLASH_ATTENTION)
|
|||
|
||||
#undef GEMMA_DECL_FLASH_ATTENTION
|
||||
|
||||
void DispatchDispatchTileFlashAttention148(
|
||||
Tile148Params& params, const MatPtrT<BF16>& q, const MatPtrT<KV_t>& k,
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx,
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out,
|
||||
size_t qkv_dim, ThreadingContext& ctx, const size_t worker,
|
||||
AttentionImpl attention_impl);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FLASH_ATTENTION_H_
|
||||
|
|
|
|||
|
|
@ -544,8 +544,6 @@ void TestAttentionMultipleTokens() {
|
|||
test_env.SetupWeights();
|
||||
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
|
||||
FillMatPtrT(test_env.activations->attention.q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_Q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_K);
|
||||
FillMatPtrT(test_env.activations->attention.att);
|
||||
FillMatPtrT(test_env.activations->attention.att_out);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_max);
|
||||
|
|
@ -590,8 +588,6 @@ void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
|
|||
test_env.SetupWeights();
|
||||
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
|
||||
FillMatPtrT(test_env.activations->attention.q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_Q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_K);
|
||||
FillMatPtrT(test_env.activations->attention.att);
|
||||
FillMatPtrT(test_env.activations->attention.att_out);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_max);
|
||||
|
|
@ -763,8 +759,6 @@ void TestAttentionMultipleTokensBF16() {
|
|||
test_env.SetupWeights();
|
||||
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
|
||||
FillMatPtrT(test_env.activations->attention.q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_Q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_K);
|
||||
FillMatPtrT(test_env.activations->attention.att);
|
||||
FillMatPtrT(test_env.activations->attention.att_out);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_max);
|
||||
|
|
@ -807,8 +801,6 @@ void TestAttentionMultipleTokensInt8() {
|
|||
test_env.SetupWeights();
|
||||
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
|
||||
FillMatPtrT(test_env.activations->attention.q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_Q);
|
||||
FillMatPtrT(test_env.activations->attention.vit_K);
|
||||
FillMatPtrT(test_env.activations->attention.att);
|
||||
FillMatPtrT(test_env.activations->attention.att_out);
|
||||
FillMatPtrT(test_env.activations->attention.softmax_max);
|
||||
|
|
|
|||
266
gemma/vit.cc
266
gemma/vit.cc
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/flash_structs.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
@ -41,6 +42,8 @@
|
|||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "gemma/attention.h"
|
||||
#include "gemma/flash_attention.h"
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
|
|
@ -68,107 +71,194 @@ class VitAttention {
|
|||
layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv);
|
||||
}
|
||||
|
||||
// TODO(philculliton): transition fully to MatMul.
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSumMatrix() {
|
||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
|
||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmaxMatrix");
|
||||
|
||||
MatPtrT<float>& Q = activations_.attention.vit_Q;
|
||||
MatPtrT<float>& K = activations_.attention.vit_K;
|
||||
MatPtrT<float>& C = activations_.attention.vit_C;
|
||||
|
||||
// Initialize att_out to zero prior to head loop.
|
||||
ZeroInit(activations_.attention.att_out);
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
pool_.Run(0, num_tokens_, caller1_,
|
||||
[&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
const size_t token = task;
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
// TODO: shift to MatMul with A.scale once MatMul is confirmed
|
||||
// working
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
|
||||
// Applies the query scale to the query and converts to QType.
|
||||
template <typename QKVType, typename QType>
|
||||
void ScaleQuery(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
|
||||
const size_t heads, const size_t qkv_dim,
|
||||
const float query_scale, MatPtrT<QType>& q_output) {
|
||||
ParallelFor(Parallelism::kWithinCluster, heads, env_.ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashAttention,
|
||||
[&](size_t head, size_t worker) {
|
||||
size_t q_offset = head * qkv_dim;
|
||||
for (size_t token = 0; token < num_tokens; ++token) {
|
||||
const float* HWY_RESTRICT src_q =
|
||||
qkv.Row(token) + q_offset * 3;
|
||||
QType* HWY_RESTRICT dst_q = q_output.Row(token) + q_offset;
|
||||
for (size_t i = 0; i < qkv_dim; ++i) {
|
||||
dst_q[i] = hwy::ConvertScalarTo<QType>(
|
||||
hwy::ConvertScalarTo<float>(src_q[i]) * query_scale);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pool_.Run(
|
||||
0, seq_len, caller2_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
const size_t seq_idx = task;
|
||||
float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) +
|
||||
head * 3 * qkv_dim + qkv_dim;
|
||||
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
|
||||
});
|
||||
|
||||
// this produces C, a (num_tokens_, seq_len) matrix of dot products
|
||||
CallMatMul(Q, K, nullptr, env_, C);
|
||||
|
||||
pool_.Run(0, num_tokens_, caller3_,
|
||||
[&](uint64_t task, size_t worker)
|
||||
HWY_ATTR { Softmax(C.RowSpan(task), env_.ctx, worker); });
|
||||
|
||||
pool_.Run(
|
||||
0, num_tokens_, caller4_, [&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
size_t token = task;
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
|
||||
// Transposes K and V and converts to KVType.
|
||||
template <typename QKVType, typename KVType>
|
||||
void TransposeKAndV(const MatPtrT<QKVType>& qkv, const size_t num_tokens,
|
||||
const size_t heads, const size_t qkv_dim,
|
||||
MatPtrT<KVType>& k_output, MatPtrT<KVType>& v_output) {
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
const size_t kNF = hn::Lanes(df);
|
||||
const size_t kNumTokensH = hwy::DivCeil(num_tokens, 2 * kNF);
|
||||
const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
|
||||
ParallelFor(
|
||||
Parallelism::kWithinCluster, heads, env_.ctx,
|
||||
/*cluster_idx=*/0, Callers::kFlashAttention,
|
||||
[&](size_t head, size_t worker) {
|
||||
const size_t qkv_offset = head * 3 * qkv_dim;
|
||||
const size_t k_or_v_offset = head * 2 * kNF * kRoundedKVDim;
|
||||
for (size_t token_h = 0; token_h < kNumTokensH; ++token_h) {
|
||||
KVType* HWY_RESTRICT dst_k = k_output.Row(token_h);
|
||||
KVType* HWY_RESTRICT dst_v = v_output.Row(token_h);
|
||||
size_t dst_k_index = k_or_v_offset;
|
||||
for (size_t q = 0; q < qkv_dim; q += 2) {
|
||||
for (size_t token_l = 0; token_l < 2 * kNF;
|
||||
++token_l, dst_k_index += 2) {
|
||||
const QKVType* HWY_RESTRICT src_k =
|
||||
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset + qkv_dim;
|
||||
dst_k[dst_k_index] = hwy::ConvertScalarTo<KVType>(src_k[q]);
|
||||
dst_k[dst_k_index + 1] =
|
||||
hwy::ConvertScalarTo<KVType>(src_k[q + 1]);
|
||||
}
|
||||
}
|
||||
});
|
||||
size_t dst_v_index = k_or_v_offset;
|
||||
for (size_t q = 0; q < qkv_dim; q += 2 * kNF) {
|
||||
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
|
||||
const QKVType* HWY_RESTRICT src_v =
|
||||
qkv.Row(token_h * 2 * kNF + token_l) + qkv_offset +
|
||||
qkv_dim * 2;
|
||||
if (q + 2 * kNF <= qkv_dim) {
|
||||
for (size_t q_l = 0; q_l < 2 * kNF; ++q_l) {
|
||||
dst_v[dst_v_index++] =
|
||||
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
|
||||
}
|
||||
} else {
|
||||
for (size_t q_l = 0; q_l < qkv_dim - q; ++q_l) {
|
||||
dst_v[dst_v_index++] =
|
||||
hwy::ConvertScalarTo<KVType>(src_v[q + q_l]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Zero out the padding area.
|
||||
// In the loops above, the dst_k loop has written 2kNF x 2
|
||||
// consecutive elements for each q +=2, and the dst_v loop has
|
||||
// written 2kNF x 2kNF consecutive elements for each q += 2 * kNF.
|
||||
// Both of them therefore write 2kNF elements for each increment of
|
||||
// q, so we can combine both into a single loop for the padding.
|
||||
// This could be further simplified by writing a zero vector.
|
||||
for (size_t q = qkv_dim; q < kRoundedKVDim; ++q) {
|
||||
for (size_t token_l = 0; token_l < 2 * kNF; ++token_l) {
|
||||
dst_k[dst_k_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
|
||||
dst_v[dst_v_index++] = hwy::ConvertScalarTo<KVType>(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Computes the flash attention parameters. This is mostly about deciding on
|
||||
// the tile sizes and filling the param structs with the correct offsets.
|
||||
template <typename QType, typename KVType>
|
||||
void ComputeParams(const uint32_t num_tokens, const size_t seq_len,
|
||||
const size_t heads, const uint32_t qkv_dim,
|
||||
const MatPtrT<QType>& q, const MatPtrT<KVType>& k,
|
||||
const MatPtrT<KVType>& v, const MatPtrT<float>& att_out,
|
||||
std::vector<Tile148Params>& flash_params) {
|
||||
flash_params.clear();
|
||||
for (uint32_t head = 0; head < heads; ++head) {
|
||||
uint32_t token = 0;
|
||||
while (token + k8xNFVTileSize <= num_tokens) {
|
||||
flash_params.push_back(Tile148Params{
|
||||
.v_tile_size = k8xNFVTileSize,
|
||||
.qi_index = token,
|
||||
.kv_head = head,
|
||||
});
|
||||
token += k8xNFVTileSize;
|
||||
}
|
||||
if (token + k4xNFVTileSize <= num_tokens) {
|
||||
flash_params.push_back(Tile148Params{
|
||||
.v_tile_size = k4xNFVTileSize,
|
||||
.qi_index = token,
|
||||
.kv_head = head,
|
||||
});
|
||||
token += k4xNFVTileSize;
|
||||
}
|
||||
while (token < num_tokens) {
|
||||
flash_params.push_back(Tile148Params{
|
||||
.v_tile_size = 1,
|
||||
.qi_index = token,
|
||||
.kv_head = head,
|
||||
});
|
||||
token += 1;
|
||||
}
|
||||
}
|
||||
for (auto& param : flash_params) {
|
||||
param.min_start_pos = 0;
|
||||
param.max_last_pos = num_tokens - 1;
|
||||
for (size_t i = 0; i < param.v_tile_size; ++i) {
|
||||
param.q_offsets[i] =
|
||||
q.Row(param.qi_index + i) + param.kv_head * qkv_dim - q.Row(0);
|
||||
param.out_offsets[i] = att_out.Row(param.qi_index + i) +
|
||||
param.kv_head * qkv_dim - att_out.Row(0);
|
||||
param.start_pos[i] = 0;
|
||||
param.last_pos[i] = num_tokens - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HWY_NOINLINE void DotSoftmaxWeightedSum() {
|
||||
// Runs the flash attention algorithm on Q, K, V.
|
||||
HWY_NOINLINE void FlashAttention() {
|
||||
GCPP_ZONE(env_.ctx, 0, Zones::kVitFlashAttentionInclusive);
|
||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations_.attention.div_seq_len.GetDivisor());
|
||||
const size_t kNF = FloatsPerVector();
|
||||
const size_t kRoundedKVDim = hwy::RoundUpTo(qkv_dim, 2 * kNF);
|
||||
auto& attn = activations_.attention;
|
||||
const size_t seq_len = static_cast<size_t>(attn.div_seq_len.GetDivisor());
|
||||
if (attn.vit_K_T.Rows() >= seq_len) {
|
||||
attn.vit_K_T.ReshapePackedRowsToCols(2 * kNF);
|
||||
attn.vit_V_T.ReshapePackedRowsToCols(2 * kNF);
|
||||
}
|
||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||
ScaleQuery(attn.q, num_tokens_, heads, qkv_dim, query_scale, attn.q_bf);
|
||||
TransposeKAndV(attn.q, num_tokens_, heads, qkv_dim, attn.vit_K_T,
|
||||
attn.vit_V_T);
|
||||
ComputeParams(num_tokens_, seq_len, heads, qkv_dim, attn.q_bf, attn.vit_K_T,
|
||||
attn.vit_V_T, attn.att_out, attn.flash_params);
|
||||
size_t num_tasks = attn.flash_params.size();
|
||||
|
||||
// Compute Q.K, softmax, and weighted V.
|
||||
pool_.Run(0, layer_config_.heads * num_tokens_, caller1_,
|
||||
[&](uint64_t task, size_t worker) HWY_ATTR {
|
||||
const size_t head = task % layer_config_.heads;
|
||||
const size_t token = task / layer_config_.heads;
|
||||
// Compute Q.K scores, which are "logits" stored in head_att.
|
||||
float* HWY_RESTRICT q =
|
||||
activations_.attention.q.Row(token) + head * 3 * qkv_dim;
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.attention.att.Row(token) + head * seq_len;
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT k = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + qkv_dim;
|
||||
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
|
||||
}
|
||||
// SoftMax yields "probabilities" in head_att.
|
||||
Softmax(Logits(head_att, seq_len), env_.ctx, worker);
|
||||
// Compute weighted sum of v into att_out.
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations_.attention.att_out.Row(token) + head * qkv_dim;
|
||||
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT v = activations_.attention.q.Row(i) +
|
||||
head * 3 * qkv_dim + 2 * qkv_dim;
|
||||
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
|
||||
}
|
||||
});
|
||||
// For each param, compute fused flash Q.K, softmax and weighted V.
|
||||
const auto func = [&, &ctx = env_.ctx](const size_t task,
|
||||
size_t worker) HWY_ATTR {
|
||||
GCPP_ZONE(ctx, worker, Zones::kFlashAttentionFlashAttention);
|
||||
auto& param = attn.flash_params[task];
|
||||
MatPtrT<KV_t> kT("k_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
|
||||
kRoundedKVDim * 2 * kNF));
|
||||
kT.SetPtr(attn.vit_K_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
|
||||
attn.vit_K_T.Stride());
|
||||
MatPtrT<KV_t> vT("v_T_view", Extents2D(hwy::DivCeil(seq_len, 2 * kNF),
|
||||
kRoundedKVDim * 2 * kNF));
|
||||
vT.SetPtr(attn.vit_V_T.Row(0) + param.kv_head * kRoundedKVDim * 2 * kNF,
|
||||
attn.vit_V_T.Stride());
|
||||
DispatchDispatchTileFlashAttention148(
|
||||
param, attn.q_bf, kT, vT, /*layer_idx=*/0, attn, attn.att_out,
|
||||
qkv_dim, ctx, worker, /*attention_impl=*/AttentionImpl::kFlash);
|
||||
};
|
||||
|
||||
{
|
||||
PROFILER_ZONE("Gen.VitFlashAttention.ForkJoin");
|
||||
// Full parallelism is helpful, SmallParallelFor is insufficient.
|
||||
HierarchicalParallelFor(num_tasks, env_.ctx, Callers::kFlashAttention,
|
||||
func);
|
||||
}
|
||||
}
|
||||
|
||||
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
|
||||
// head_dim (`qkv_dim`) into output (`att_sums`).
|
||||
HWY_NOINLINE void SumHeads() {
|
||||
PROFILER_ZONE("Gen.VitAttention.SumHeads");
|
||||
auto* bias = layer_.vit.attn_out_b.PackedScale1();
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
|
||||
|
|
@ -193,11 +283,7 @@ class VitAttention {
|
|||
|
||||
HWY_INLINE void operator()() {
|
||||
ComputeQKV();
|
||||
if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) {
|
||||
DotSoftmaxWeightedSumMatrix();
|
||||
} else {
|
||||
DotSoftmaxWeightedSum();
|
||||
}
|
||||
FlashAttention();
|
||||
SumHeads();
|
||||
}
|
||||
|
||||
|
|
|
|||
107
ops/ops-inl.h
107
ops/ops-inl.h
|
|
@ -669,10 +669,10 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
|
|||
size_t i = 0;
|
||||
while (i + NF * 2 <= size) {
|
||||
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
|
||||
out0a = hn::Load(df, out + i + out_offsets[0]);
|
||||
out1a = hn::Load(df, out + i + out_offsets[1]);
|
||||
out2a = hn::Load(df, out + i + out_offsets[2]);
|
||||
out3a = hn::Load(df, out + i + out_offsets[3]);
|
||||
out0a = hn::LoadU(df, out + i + out_offsets[0]);
|
||||
out1a = hn::LoadU(df, out + i + out_offsets[1]);
|
||||
out2a = hn::LoadU(df, out + i + out_offsets[2]);
|
||||
out3a = hn::LoadU(df, out + i + out_offsets[3]);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
VF scale1 = hn::Set(df, scales[1]);
|
||||
VF scale2 = hn::Set(df, scales[2]);
|
||||
|
|
@ -681,28 +681,70 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
|
|||
out1a = hn::Mul(out1a, scale1);
|
||||
out2a = hn::Mul(out2a, scale2);
|
||||
out3a = hn::Mul(out3a, scale3);
|
||||
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
|
||||
out1b = hn::Load(df, out + i + NF + out_offsets[1]);
|
||||
out2b = hn::Load(df, out + i + NF + out_offsets[2]);
|
||||
out3b = hn::Load(df, out + i + NF + out_offsets[3]);
|
||||
out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
|
||||
out1b = hn::LoadU(df, out + i + NF + out_offsets[1]);
|
||||
out2b = hn::LoadU(df, out + i + NF + out_offsets[2]);
|
||||
out3b = hn::LoadU(df, out + i + NF + out_offsets[3]);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
out1b = hn::Mul(out1b, scale1);
|
||||
out2b = hn::Mul(out2b, scale2);
|
||||
out3b = hn::Mul(out3b, scale3);
|
||||
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
|
||||
out2a, out3a, out0b, out1b, out2b, out3b);
|
||||
hn::Store(out0a, df, out + i + out_offsets[0]);
|
||||
hn::Store(out1a, df, out + i + out_offsets[1]);
|
||||
hn::Store(out2a, df, out + i + out_offsets[2]);
|
||||
hn::Store(out3a, df, out + i + out_offsets[3]);
|
||||
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
|
||||
hn::Store(out1b, df, out + i + NF + out_offsets[1]);
|
||||
hn::Store(out2b, df, out + i + NF + out_offsets[2]);
|
||||
hn::Store(out3b, df, out + i + NF + out_offsets[3]);
|
||||
hn::StoreU(out0a, df, out + i + out_offsets[0]);
|
||||
hn::StoreU(out1a, df, out + i + out_offsets[1]);
|
||||
hn::StoreU(out2a, df, out + i + out_offsets[2]);
|
||||
hn::StoreU(out3a, df, out + i + out_offsets[3]);
|
||||
hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
|
||||
hn::StoreU(out1b, df, out + i + NF + out_offsets[1]);
|
||||
hn::StoreU(out2b, df, out + i + NF + out_offsets[2]);
|
||||
hn::StoreU(out3b, df, out + i + NF + out_offsets[3]);
|
||||
i += NF * 2;
|
||||
v_bf += 4 * NF * NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
if (i < size) {
|
||||
VF out0a, out1a, out2a, out3a, out0b, out1b, out2b, out3b;
|
||||
out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
|
||||
out1a = hn::LoadN(df, out + i + out_offsets[1], size - i);
|
||||
out2a = hn::LoadN(df, out + i + out_offsets[2], size - i);
|
||||
out3a = hn::LoadN(df, out + i + out_offsets[3], size - i);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
VF scale1 = hn::Set(df, scales[1]);
|
||||
VF scale2 = hn::Set(df, scales[2]);
|
||||
VF scale3 = hn::Set(df, scales[3]);
|
||||
out0a = hn::Mul(out0a, scale0);
|
||||
out1a = hn::Mul(out1a, scale1);
|
||||
out2a = hn::Mul(out2a, scale2);
|
||||
out3a = hn::Mul(out3a, scale3);
|
||||
if (i + NF < size) {
|
||||
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
|
||||
out1b = hn::LoadN(df, out + i + NF + out_offsets[1], size - i - NF);
|
||||
out2b = hn::LoadN(df, out + i + NF + out_offsets[2], size - i - NF);
|
||||
out3b = hn::LoadN(df, out + i + NF + out_offsets[3], size - i - NF);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
out1b = hn::Mul(out1b, scale1);
|
||||
out2b = hn::Mul(out2b, scale2);
|
||||
out3b = hn::Mul(out3b, scale3);
|
||||
} else {
|
||||
out0b = hn::Zero(df);
|
||||
out1b = hn::Zero(df);
|
||||
out2b = hn::Zero(df);
|
||||
out3b = hn::Zero(df);
|
||||
}
|
||||
// Note that v_bf is always padded, so we can always load 2 * NF elements.
|
||||
MulAddNLanesVT4(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out1a,
|
||||
out2a, out3a, out0b, out1b, out2b, out3b);
|
||||
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
|
||||
hn::StoreN(out1a, df, out + i + out_offsets[1], size - i);
|
||||
hn::StoreN(out2a, df, out + i + out_offsets[2], size - i);
|
||||
hn::StoreN(out3a, df, out + i + out_offsets[3], size - i);
|
||||
if (i + NF < size) {
|
||||
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
|
||||
hn::StoreN(out1b, df, out + i + NF + out_offsets[1], size - i - NF);
|
||||
hn::StoreN(out2b, df, out + i + NF + out_offsets[2], size - i - NF);
|
||||
hn::StoreN(out3b, df, out + i + NF + out_offsets[3], size - i - NF);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
|
|
@ -743,26 +785,33 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT1Mem(
|
|||
size_t i = 0;
|
||||
while (i + NF * 2 <= size) {
|
||||
VF out0a, out0b;
|
||||
out0a = hn::Load(df, out + i + out_offsets[0]);
|
||||
out0a = hn::LoadU(df, out + i + out_offsets[0]);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
out0a = hn::Mul(out0a, scale0);
|
||||
out0b = hn::Load(df, out + i + NF + out_offsets[0]);
|
||||
out0b = hn::LoadU(df, out + i + NF + out_offsets[0]);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
|
||||
hn::Store(out0a, df, out + i + out_offsets[0]);
|
||||
hn::Store(out0b, df, out + i + NF + out_offsets[0]);
|
||||
hn::StoreU(out0a, df, out + i + out_offsets[0]);
|
||||
hn::StoreU(out0b, df, out + i + NF + out_offsets[0]);
|
||||
i += NF * 2;
|
||||
v_bf += 4 * NF * NF;
|
||||
}
|
||||
while (i < size) {
|
||||
float sum = out[i + out_offsets[0]] * scales[0];
|
||||
const BF16* HWY_RESTRICT v_local = v_bf;
|
||||
for (size_t lane = 0; lane < HWY_MIN(num_lanes, 2 * NF);
|
||||
++lane, v_local += 2 * NF) {
|
||||
sum += hwy::ConvertScalarTo<float>(*v_local) * c_mem[lane];
|
||||
if (i < size) {
|
||||
VF out0a, out0b;
|
||||
out0a = hn::LoadN(df, out + i + out_offsets[0], size - i);
|
||||
VF scale0 = hn::Set(df, scales[0]);
|
||||
out0a = hn::Mul(out0a, scale0);
|
||||
if (i + NF < size) {
|
||||
out0b = hn::LoadN(df, out + i + NF + out_offsets[0], size - i - NF);
|
||||
out0b = hn::Mul(out0b, scale0);
|
||||
} else {
|
||||
out0b = hn::Zero(df);
|
||||
}
|
||||
MulAddNLanesVT1(df, v_bf, c_mem, HWY_MIN(num_lanes, 2 * NF), out0a, out0b);
|
||||
hn::StoreN(out0a, df, out + i + out_offsets[0], size - i);
|
||||
if (i + NF < size) {
|
||||
hn::StoreN(out0b, df, out + i + NF + out_offsets[0], size - i - NF);
|
||||
}
|
||||
++i;
|
||||
++v_bf;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ const char* ZoneName(Zones zone) {
|
|||
return "FlashAttention.FlashAttention";
|
||||
case Zones::kFlashAttentionInclusive:
|
||||
return "FlashAttention.Inclusive";
|
||||
case Zones::kVitFlashAttentionInclusive:
|
||||
return "Vit.FlashAttention.Inclusive";
|
||||
case Zones::kFlashAttentionRmsNormAndPositionalEncoding:
|
||||
return "FlashAttention.RMSNormAndPositionalEncoding";
|
||||
case Zones::kFlashAttentionTileFlashAttention1:
|
||||
|
|
@ -106,6 +108,7 @@ const char* ZoneName(Zones zone) {
|
|||
hwy::ProfilerFlags ZoneFlags(Zones zone) {
|
||||
switch (zone) {
|
||||
case Zones::kFlashAttentionInclusive:
|
||||
case Zones::kVitFlashAttentionInclusive:
|
||||
case Zones::kGenAttention:
|
||||
case Zones::kGenAttentionComputeQKV:
|
||||
case Zones::kGenAttentionDotSoftmaxWeightedSumInclusive:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ namespace gcpp {
|
|||
enum class Zones { // Keep sorted
|
||||
kFlashAttentionFlashAttention,
|
||||
kFlashAttentionInclusive,
|
||||
kVitFlashAttentionInclusive,
|
||||
kFlashAttentionRmsNormAndPositionalEncoding,
|
||||
kFlashAttentionTileFlashAttention1,
|
||||
kFlashAttentionTileFlashAttention4,
|
||||
|
|
|
|||
Loading…
Reference in New Issue