Implementation of tiled attention with bf16 and circular buffers which reduces memory requirements by 4x on longer context on gemma models.

It also supports better parallelism for small batch sizes / small models.
It also is able to utilize VDPBF16PS for nice 2x improvement on avx512

PiperOrigin-RevId: 874517319
This commit is contained in:
Krzysztof Rymski 2026-02-24 03:26:23 -08:00 committed by Copybara-Service
parent 463a3682be
commit df162ead7c
15 changed files with 3056 additions and 38 deletions

View File

@ -652,10 +652,12 @@ cc_library(
name = "gemma_lib",
srcs = [
"gemma/gemma.cc",
"gemma/tiled_attention.cc",
"gemma/vit.cc",
],
hdrs = [
"gemma/gemma.h",
"gemma/tiled_attention.h",
"gemma/vit.h",
],
exec_properties = {

View File

@ -93,6 +93,8 @@ set(SOURCES
gemma/model_store.h
gemma/tensor_info.cc
gemma/tensor_info.h
gemma/tiled_attention.cc
gemma/tiled_attention.h
gemma/tokenizer.cc
gemma/tokenizer.h
gemma/vit.cc
@ -171,20 +173,20 @@ install(TARGETS libgemma DESTINATION lib)
if(BUILD_GEMMA_DLL)
add_library(gemma_shared SHARED ${SOURCES})
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
set_target_properties(gemma_shared PROPERTIES
set_target_properties(gemma_shared PROPERTIES
PREFIX ""
OUTPUT_NAME "gemma"
)
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
target_include_directories(gemma_shared PUBLIC ./)
target_link_libraries(gemma_shared PRIVATE
target_link_libraries(gemma_shared PRIVATE
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
)
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(gemma_shared
PRIVATE
target_compile_definitions(gemma_shared
PRIVATE
GEMMA_EXPORTS
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
)

View File

@ -153,6 +153,14 @@ struct AttentionActivations {
// Accumulation of attention outputs over heads
MatStorageT<BF16> att_sums;
MatStorageT<float> k_tile_vec;
MatStorageT<float> v_tile_vec;
std::vector<MatStorageT<float>> sub_task_att_out;
std::vector<AlignedFloatVector>
sub_task_exp_denominator_sums;
std::vector<AlignedFloatVector>
sub_task_max_logits;
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
@ -244,6 +252,16 @@ struct AttentionActivationsPtrs {
// Accumulation of attention outputs over heads, size batch_size x
// model_dim.
MatPtrT<BF16> att_sums;
// Stores intermediate results of computing QKV,
// [qbatch * kv_heads , k_tile_size * qkv_dim]
MatPtrT<float> k_tile_vec;
MatPtrT<float> v_tile_vec;
// Used by TiledFlashAttention to store intermediate results.
std::vector<MatStorageT<float>>* sub_task_att_out;
std::vector<AlignedFloatVector>*
sub_task_exp_denominator_sums;
std::vector<AlignedFloatVector>*
sub_task_max_logits;
// Inverse timescales for RoPE computation.
MatPtrT<float> inv_timescale;
// Inverse timescales for global RoPE computation.

View File

@ -83,6 +83,8 @@ static inline bool EnumValid(LayerAttentionType type) {
enum class AttentionImpl {
kOld,
kFlash,
kFlashTransposedQs,
kFlashTransposedQsBF16,
kSentinel,
};
@ -108,6 +110,8 @@ static inline int AttentionImplToFlags(AttentionImpl impl,
case AttentionImpl::kOld:
return kAttentionUseOld;
case AttentionImpl::kFlash:
case AttentionImpl::kFlashTransposedQs:
case AttentionImpl::kFlashTransposedQsBF16:
default:
return 0;
}

View File

@ -921,6 +921,620 @@ Tile4FlashState TileFlashAttention4(
return state;
}
template <int kNumQueries, typename Q_T, class DQ_T, class VQ_T = hn::Vec<DQ_T>,
typename T>
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
DQ_T df, const Q_T* HWY_RESTRICT q, const Q_T* HWY_RESTRICT q2,
const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VQ_T& sum0_p0,
VQ_T& sum0_p1, VQ_T& sum1_p0, VQ_T& sum1_p1, VQ_T& sum2_p0, VQ_T& sum2_p1,
VQ_T& sum3_p0, VQ_T& sum3_p1, VQ_T& sum4_p0, VQ_T& sum4_p1, VQ_T& sum5_p0,
VQ_T& sum5_p1, VQ_T& sum6_p0, VQ_T& sum6_p1, VQ_T& sum7_p0, VQ_T& sum7_p1) {
const PackedSpan<const T> k_transposed_span =
MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim);
HWY_DASSERT(kNumQueries <= 8);
HWY_DASSERT(gcpp::KVCache::kTileSize >=
hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time.
sum0_p0 = hn::Zero(df);
sum0_p1 = hn::Zero(df);
if constexpr (kNumQueries >= 2) {
sum1_p0 = hn::Zero(df);
sum1_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 3) {
sum2_p0 = hn::Zero(df);
sum2_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 4) {
sum3_p0 = hn::Zero(df);
sum3_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 5) {
sum4_p0 = hn::Zero(df);
sum4_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 6) {
sum5_p0 = hn::Zero(df);
sum5_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 7) {
sum6_p0 = hn::Zero(df);
sum6_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 8) {
sum7_p0 = hn::Zero(df);
sum7_p1 = hn::Zero(df);
}
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
kNumQueries - kFirstHalfAmountOfQueries;
HWY_UNROLL(1)
for (size_t i = 0; i < qkv_dim; ++i) {
VQ_T k_vec1, k_vec2;
if constexpr (HWY_TARGET == HWY_AVX2) {
hwy::Prefetch(k_transposed_span.ptr + (i + 3) * gcpp::KVCache::kTileSize);
hwy::Prefetch(k_transposed_span.ptr + (i + 4) * gcpp::KVCache::kTileSize);
}
Decompress2(df, k_transposed_span, i * gcpp::KVCache::kTileSize, k_vec1,
k_vec2);
sum0_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p0);
sum0_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p1);
if constexpr (kNumQueries >= 2) {
sum1_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p0);
sum1_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p1);
}
if constexpr (kNumQueries >= 3) {
sum2_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p0);
sum2_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p1);
}
if constexpr (kNumQueries >= 4) {
sum3_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p0);
sum3_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p1);
}
if constexpr (kNumQueries >= 5) {
sum4_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p0);
sum4_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p1);
}
if constexpr (kNumQueries >= 6) {
sum5_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p0);
sum5_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p1);
}
if constexpr (kNumQueries >= 7) {
sum6_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p0);
sum6_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p1);
}
if constexpr (kNumQueries >= 8) {
sum7_p0 = hn::MulAdd(
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p0);
sum7_p1 = hn::MulAdd(
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p1);
}
}
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename T>
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2,
const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) {
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
const PackedSpan<const T> k_transposed_span =
MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim);
[[maybe_unused]] HWY_LANES_CONSTEXPR size_t lanes_bf16 = hn::Lanes(dbf);
HWY_DASSERT(hn::Lanes(dbf) <= gcpp::KVCache::kTileSize);
HWY_DASSERT(kNumQueries <= 8);
HWY_DASSERT(gcpp::KVCache::kTileSize >=
hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time.
sum0_p0 = hn::Zero(df);
sum0_p1 = hn::Zero(df);
if constexpr (kNumQueries >= 2) {
sum1_p0 = hn::Zero(df);
sum1_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 3) {
sum2_p0 = hn::Zero(df);
sum2_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 4) {
sum3_p0 = hn::Zero(df);
sum3_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 5) {
sum4_p0 = hn::Zero(df);
sum4_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 6) {
sum5_p0 = hn::Zero(df);
sum5_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 7) {
sum6_p0 = hn::Zero(df);
sum6_p1 = hn::Zero(df);
}
if constexpr (kNumQueries >= 8) {
sum7_p0 = hn::Zero(df);
sum7_p1 = hn::Zero(df);
}
VF helper_sum0_p0 = hn::Zero(df), helper_sum0_p1 = hn::Zero(df);
VF helper_sum1_p0 = hn::Zero(df), helper_sum1_p1 = hn::Zero(df);
VF helper_sum2_p0 = hn::Zero(df), helper_sum2_p1 = hn::Zero(df);
VF helper_sum3_p0 = hn::Zero(df), helper_sum3_p1 = hn::Zero(df);
VF helper_sum4_p0 = hn::Zero(df), helper_sum4_p1 = hn::Zero(df);
VF helper_sum5_p0 = hn::Zero(df), helper_sum5_p1 = hn::Zero(df);
VF helper_sum6_p0 = hn::Zero(df), helper_sum6_p1 = hn::Zero(df);
VF helper_sum7_p0 = hn::Zero(df), helper_sum7_p1 = hn::Zero(df);
const float* q_float_ptr = HWY_RCAST_ALIGNED(const float*, q);
const float* q2_float_ptr = HWY_RCAST_ALIGNED(const float*, q2);
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
kNumQueries - kFirstHalfAmountOfQueries;
for (size_t i = 0; i < qkv_dim / 2; i++) {
VBF k_vec1, k_vec2;
Decompress2(dbf, k_transposed_span, i * 2 * gcpp::KVCache::kTileSize,
k_vec1, k_vec2);
VF q_0_as_float = hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries]);
VBF q_0 = hn::BitCast(dbf, q_0_as_float);
sum0_p0 =
hn::ReorderWidenMulAccumulate(df, k_vec1, q_0, sum0_p0, helper_sum0_p0);
sum0_p1 =
hn::ReorderWidenMulAccumulate(df, k_vec2, q_0, sum0_p1, helper_sum0_p1);
if constexpr (kNumQueries >= 2) {
VF q_1_as_float =
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 1]);
VBF q_1 = hn::BitCast(dbf, q_1_as_float);
sum1_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_1, sum1_p0,
helper_sum1_p0);
sum1_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_1, sum1_p1,
helper_sum1_p1);
}
if constexpr (kNumQueries >= 3) {
VF q_2_as_float =
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 2]);
VBF q_2 = hn::BitCast(dbf, q_2_as_float);
sum2_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_2, sum2_p0,
helper_sum2_p0);
sum2_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_2, sum2_p1,
helper_sum2_p1);
}
if constexpr (kNumQueries >= 4) {
VF q_3_as_float =
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 3]);
VBF q_3 = hn::BitCast(dbf, q_3_as_float);
sum3_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_3, sum3_p0,
helper_sum3_p0);
sum3_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_3, sum3_p1,
helper_sum3_p1);
}
if constexpr (kNumQueries >= 5) {
VF q_4_as_float =
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 0]);
VBF q_4 = hn::BitCast(dbf, q_4_as_float);
sum4_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_4, sum4_p0,
helper_sum4_p0);
sum4_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_4, sum4_p1,
helper_sum4_p1);
}
if constexpr (kNumQueries >= 6) {
VF q_5_as_float =
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 1]);
VBF q_5 = hn::BitCast(dbf, q_5_as_float);
sum5_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_5, sum5_p0,
helper_sum5_p0);
sum5_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_5, sum5_p1,
helper_sum5_p1);
}
if constexpr (kNumQueries >= 7) {
VF q_6_as_float =
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 2]);
VBF q_6 = hn::BitCast(dbf, q_6_as_float);
sum6_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_6, sum6_p0,
helper_sum6_p0);
sum6_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_6, sum6_p1,
helper_sum6_p1);
}
if constexpr (kNumQueries >= 8) {
VF q_7_as_float =
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 3]);
VBF q_7 = hn::BitCast(dbf, q_7_as_float);
sum7_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_7, sum7_p0,
helper_sum7_p0);
sum7_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_7, sum7_p1,
helper_sum7_p1);
}
}
#if HWY_NATIVE_DOT_BF16 == 0
sum0_p0 = hn::Add(sum0_p0, helper_sum0_p0);
sum0_p1 = hn::Add(sum0_p1, helper_sum0_p1);
if constexpr (kNumQueries >= 2) {
sum1_p0 = hn::Add(sum1_p0, helper_sum1_p0);
sum1_p1 = hn::Add(sum1_p1, helper_sum1_p1);
}
if constexpr (kNumQueries >= 3) {
sum2_p0 = hn::Add(sum2_p0, helper_sum2_p0);
sum2_p1 = hn::Add(sum2_p1, helper_sum2_p1);
}
if constexpr (kNumQueries >= 4) {
sum3_p0 = hn::Add(sum3_p0, helper_sum3_p0);
sum3_p1 = hn::Add(sum3_p1, helper_sum3_p1);
}
if constexpr (kNumQueries >= 5) {
sum4_p0 = hn::Add(sum4_p0, helper_sum4_p0);
sum4_p1 = hn::Add(sum4_p1, helper_sum4_p1);
}
if constexpr (kNumQueries >= 6) {
sum5_p0 = hn::Add(sum5_p0, helper_sum5_p0);
sum5_p1 = hn::Add(sum5_p1, helper_sum5_p1);
}
if constexpr (kNumQueries >= 7) {
sum6_p0 = hn::Add(sum6_p0, helper_sum6_p0);
sum6_p1 = hn::Add(sum6_p1, helper_sum6_p1);
}
if constexpr (kNumQueries >= 8) {
sum7_p0 = hn::Add(sum7_p0, helper_sum7_p0);
sum7_p1 = hn::Add(sum7_p1, helper_sum7_p1);
}
#endif
}
template <int kVTileSize, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap,
VF& x0, VF& x1, VF& x2, VF& x3, VF& x4,
VF& x5, VF& x6, VF& x7) {
if (att_cap > 0.0f) {
VF cap = hn::Set(df, att_cap);
VF one_over_cap_vec = hn::Set(df, one_over_cap);
x0 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x0, one_over_cap_vec)));
if constexpr (kVTileSize >= 2) {
x1 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x1, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 3) {
x2 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x2, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 4) {
x3 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x3, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 5) {
x4 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x4, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 6) {
x5 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x5, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 7) {
x6 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x6, one_over_cap_vec)));
}
if constexpr (kVTileSize >= 8) {
x7 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x7, one_over_cap_vec)));
}
}
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename DU,
class VU = hn::Vec<DU>>
static HWY_NOINLINE void ApplyMasking(
DF df, DU du, size_t position,
const size_t* HWY_RESTRICT first_pos_per_query,
const size_t* HWY_RESTRICT last_pos_per_query, VF& x0_p0, VF& x0_p1,
VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0,
VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0,
VF& x7_p1) {
VU lane_indices = hn::Iota(du, 0);
HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df);
auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices);
auto per_lane_pos_p1 =
hn::Add(hn::Set(du, position + kTileSize), lane_indices);
VF neg_inf = hn::Set(df, kNegInf);
auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR {
const size_t first_pos = first_pos_per_query[query_idx];
const size_t last_pos = last_pos_per_query[query_idx];
auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos));
valid_tokens_mask_p0 = hn::And(
valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos)));
x_p0 =
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf);
auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos));
valid_tokens_mask_p1 = hn::And(
valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos)));
x_p1 =
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf);
};
if constexpr (kNumQueries >= 1) {
apply_mask_for_query(0, x0_p0, x0_p1);
}
if constexpr (kNumQueries >= 2) {
apply_mask_for_query(1, x1_p0, x1_p1);
}
if constexpr (kNumQueries >= 3) {
apply_mask_for_query(2, x2_p0, x2_p1);
}
if constexpr (kNumQueries >= 4) {
apply_mask_for_query(3, x3_p0, x3_p1);
}
if constexpr (kNumQueries >= 5) {
apply_mask_for_query(4, x4_p0, x4_p1);
}
if constexpr (kNumQueries >= 6) {
apply_mask_for_query(5, x5_p0, x5_p1);
}
if constexpr (kNumQueries >= 7) {
apply_mask_for_query(6, x6_p0, x6_p1);
}
if constexpr (kNumQueries >= 8) {
apply_mask_for_query(7, x7_p0, x7_p1);
}
}
// Performs tiled flash attention for arbitrary number of queries
// It depends on kv being tiled.
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
// It moves NF*2 timesteps forward in kv at a time.
// Args:
// kvs - hwy::Span of MatPtrT<KV_T> of shape (kvs, (tile_count, qkv_dim *
// kTileSize * 2)) This span allows to pass kv cache that is not contiguous,
// all except for the last one should have theirs row count be true,
// as it will be used to figure out when to switch to the next one.
// q_T_in_groups_up_to_4 - Span of float* All except last float*
// should have (qkv_dim, 4) Last one can have any size up to 4.
// start_pos_per_query - start position in kv to start attention from ()
// last_pos_per_query - last position in kv to attend to (exclusive)
// queries_per_timestep - how many queries begin/end on the same timestep
// attention_shape - see struct definition for more details.
// att_cap - soft cap on attention logits
// att_out - MatPtrT<float> of shape (q_count, qkv_dim)
// exp_denominator_sums and max_logits: float* of shape:
// (RountedUpTo(q_count,4),)
// Need to be have multiple of 4 elements alocated and
// be initizalized If you need to compute over multiple chunks of kv's you can
// keep values between calls to this function and avoid explicit merge.
template <typename KV_T, typename Q_T>
HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
const hwy::Span<const MatPtrT<KV_T>> kvs, int q_count,
const hwy::Span<const Q_T * HWY_RESTRICT> q_T_in_groups_up_to_4,
hwy::Span<const size_t> start_pos_per_query,
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
float* HWY_RESTRICT max_logits) {
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
using DU = hn::ScalableTag<uint32_t>;
[[maybe_unused]] const DU du;
constexpr int kTileSize = gcpp::KVCache::kTileSize;
HWY_LANES_CONSTEXPR size_t kHTileSize = hn::Lanes(df);
constexpr int kNumQueriesPerGroup = 4;
constexpr int kNumQueriesPerLoop =
(!HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3)) ? 8 : 4;
constexpr int kNumGroupsPerLoop = kNumQueriesPerLoop / kNumQueriesPerGroup;
const size_t full_groups_of_queries = q_count / kNumQueriesPerGroup;
const size_t num_loops = hwy::DivCeil(q_count, kNumQueriesPerLoop);
const size_t qkv_dim = att_out.Cols();
HWY_DASSERT(kHTileSize <= hn::MaxLanes(df));
HWY_LANES_CONSTEXPR size_t step_size = kHTileSize * 2;
size_t smallest_start_pos = std::numeric_limits<size_t>::max();
size_t largest_last_pos = std::numeric_limits<size_t>::min();
for (size_t i = 0; i < start_pos_per_query.size(); ++i) {
smallest_start_pos = std::min(smallest_start_pos, start_pos_per_query[i]);
largest_last_pos = std::max(largest_last_pos, last_pos_per_query[i]);
}
// start / end positions per group of 4 queries.
std::vector<size_t, hwy::AlignedAllocator<size_t>> pos_data(num_loops * 4);
hwy::Span<size_t> min_start_pos_per_group(pos_data.data(), num_loops);
hwy::Span<size_t> max_start_pos_per_group(pos_data.data() + num_loops,
num_loops);
hwy::Span<size_t> min_last_pos_per_group(pos_data.data() + 2 * num_loops,
num_loops);
hwy::Span<size_t> max_last_pos_per_group(pos_data.data() + 3 * num_loops,
num_loops);
for (size_t i = 0; i < num_loops; ++i) {
size_t min_start = std::numeric_limits<size_t>::max();
size_t max_start = 0;
size_t min_last = std::numeric_limits<size_t>::max();
size_t max_last = 0;
for (int j = 0; j < kNumQueriesPerLoop; ++j) {
if (i * kNumQueriesPerLoop + j < q_count) {
min_start = std::min(min_start,
start_pos_per_query[i * kNumQueriesPerLoop + j]);
max_start = std::max(max_start,
start_pos_per_query[i * kNumQueriesPerLoop + j]);
min_last =
std::min(min_last, last_pos_per_query[i * kNumQueriesPerLoop + j]);
max_last =
std::max(max_last, last_pos_per_query[i * kNumQueriesPerLoop + j]);
}
}
min_start_pos_per_group[i] = min_start;
max_start_pos_per_group[i] = max_start;
min_last_pos_per_group[i] = min_last;
max_last_pos_per_group[i] = max_last;
}
const size_t base_pos = smallest_start_pos - (smallest_start_pos % kTileSize);
const size_t rem = smallest_start_pos % kTileSize;
const size_t num_skipped_sub_tiles = rem / step_size;
size_t position = base_pos + num_skipped_sub_tiles * step_size;
[[maybe_unused]] float one_over_cap = 1.0f / att_cap;
std::vector<MatPtrT<float>> att_out_per_query;
att_out_per_query.reserve(num_loops);
for (size_t i = 0; i < num_loops; ++i) {
att_out_per_query.emplace_back("att_out",
Extents2D(kNumQueriesPerLoop, qkv_dim));
att_out_per_query.back().SetPtr(att_out.Row(i * kNumQueriesPerLoop),
att_out.Stride());
}
size_t current_kv_start_offset = 0;
size_t current_kv_idx = 0;
auto inner_loop = [&]<int kNumQueries>(int q_group_idx) HWY_ATTR {
int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup);
if (position + step_size <= min_start_pos_per_group[loop_idx] ||
position > max_last_pos_per_group[loop_idx]) {
return;
}
VF x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1;
VF x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1;
const size_t pos_in_tile = position % kTileSize;
// tile base can point to same tile as previous loop iteration, hence no
// HWY_RESTRICT
// KVs are unaligned and we only use unaligned loads in this implementation.
const KV_T* tile_base =
reinterpret_cast<const KV_T*>(kvs[current_kv_idx].RowBytes(
(position - current_kv_start_offset) / kTileSize));
const KV_T* v_tile =
tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim;
const Q_T* q_group = q_T_in_groups_up_to_4[q_group_idx];
const Q_T* q2_group = nullptr;
if (kNumQueries > 4) {
q2_group = q_T_in_groups_up_to_4[q_group_idx + 1];
}
if constexpr (IsF32<Q_T>()) {
const KV_T* k_transposed_tile = tile_base + pos_in_tile;
QDotKTilexUpTo8TransposedKDoubleWidth<kNumQueries>(
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
} else if constexpr (IsBF16<Q_T>()) {
const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
QDotKTilexUpTo8TransposedKDoubleWidthBF16<kNumQueries>(
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
} else {
static_assert(
false,
"Query type type not supported, only float and BF16 are supported");
}
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
kNumQueries - kFirstHalfAmountOfQueries;
ApplySoftCap<kFirstHalfAmountOfQueries * 2>(
df, att_cap, one_over_cap, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0,
x_2_p_1, x_3_p_0, x_3_p_1);
if constexpr (kNumQueries > 4) {
ApplySoftCap<kSecondHalfAmountOfQueries * 2>(
df, att_cap, one_over_cap, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}
if (position < max_start_pos_per_group[loop_idx] ||
position + step_size - 1 > min_last_pos_per_group[loop_idx]) {
ApplyMasking<kNumQueries>(
df, du, position,
start_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
last_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
x_7_p_0, x_7_p_1);
}
HWY_ALIGN float scales[kNumQueriesPerLoop];
// HWY_UNROLL(kNumQueriesPerLoop)
for (size_t i = 0; i < kNumQueriesPerLoop; ++i) {
scales[i] = 1.0f;
}
FlashAttentionTileStepAndApplySoftCap<kNumQueries>(
df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
kNumQueriesPerGroup);
if constexpr (IsF32<Q_T>()) {
MulByConstAndAddTileUpTo8<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
} else if constexpr (IsBF16<Q_T>()) {
MulByConstAndAddTileUpTo8_BF16<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
}
};
while (position <= largest_last_pos) {
while (position - current_kv_start_offset >=
kvs[current_kv_idx].Rows() * kTileSize) {
current_kv_start_offset += kvs[current_kv_idx].Rows() * kTileSize;
current_kv_idx++;
}
int group_idx = 0;
for (; group_idx + kNumGroupsPerLoop <= full_groups_of_queries;
group_idx += kNumGroupsPerLoop) {
inner_loop.template operator()<kNumQueriesPerLoop>(group_idx);
}
if (group_idx < full_groups_of_queries) {
inner_loop.template operator()<4>(group_idx);
group_idx++;
}
switch (q_count % kNumQueriesPerGroup) {
case 1:
inner_loop.template operator()<1>(group_idx);
break;
case 2:
inner_loop.template operator()<2>(group_idx);
break;
case 3:
inner_loop.template operator()<3>(group_idx);
break;
default:
break;
}
position += step_size;
}
}
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
hwy::Span<const MatPtr> kvs, int q_count,
const hwy::Span<const float* HWY_RESTRICT> q_T_in_groups_up_to_4,
hwy::Span<const size_t> start_pos_per_query,
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
return TileFlashAttentionReturnExpSumsAndMaxLogits(
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
});
}
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
hwy::Span<const MatPtr> kvs, int q_count,
const hwy::Span<const BF16 * HWY_RESTRICT> q_T_in_groups_up_to_4,
hwy::Span<const size_t> start_pos_per_query,
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
return TileFlashAttentionReturnExpSumsAndMaxLogits(
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
});
}
// Rounds n to a number that can be used as the number of Q rows in a tile
// of flash attention.
static size_t RoundToSuitablePowerOf2(size_t n) {

View File

@ -22,46 +22,78 @@
#include <cstdint>
#include "gemma/configs.h"
#include "gemma/flash_structs.h"
#include "gemma/kv_cache.h"
#include "gemma/query.h"
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
ThreadingContext& ctx, const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx, AttentionImpl attention_impl); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void RMSNormAndPositionalEncoding( \
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
const MatPtr& query_norm_scale, size_t layer_idx, \
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
\
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
const BF16* HWY_RESTRICT q, \
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
float* HWY_RESTRICT att_out, \
ThreadingContext& ctx, size_t worker); \
\
Tile4FlashState TileFlashAttention4( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const MatPtrT<KV_t>& k, size_t start_pos, \
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
const size_t worker); \
\
void TileFlashAttention( \
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, \
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, \
const size_t min_last_pos, const size_t max_last_pos, \
const MatPtrT<KV_t>& v, const size_t layer_idx, \
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
const size_t worker); \
\
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
size_t total_tasks, size_t target_parallelism); \
\
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
size_t layer_idx, const MatPtr& query_norm_scale, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx, AttentionImpl attention_impl); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const float* HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const BF16 * HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the

View File

@ -181,6 +181,298 @@ void TestAttention() {
TestFlashAttention(256);
}
const std::vector<float> exp_denominator_sums_gold = {
58.722088f, 58.445938f, 58.17153f, 57.89886f,
58.580994f, 58.302643f, 58.026085f, 57.751308f};
const std::vector<float> max_logits_gold = {
0.009613638f, 0.019227259f, 0.02884084f, 0.038454376f,
0.04888253f, 0.058658823f, 0.06843502f, 0.078211054f};
const std::vector<float> att_out_gold = {
0.600945, 0.300472, 0.200315, 0.150236, 0.120189, 0.100158, 0.085849,
0.075118, 0.066772, 0.060095, 0.054631, 0.050079, 0.046227, 0.042925,
0.040063, 0.037559, 0.035350, 0.033386, 0.031629, 0.030047, 0.028616,
0.027316, 0.026128, 0.025039, 0.024038, 0.023113, 0.022257, 0.021462,
0.020722, 0.020032, 0.019385, 0.018780, 0.018210, 0.017675, 0.017170,
0.016693, 0.016242, 0.015814, 0.015409, 0.015024, 0.014657, 0.014308,
0.013975, 0.013658, 0.013354, 0.013064, 0.012786, 0.012520, 0.012264,
0.012019, 0.011783, 0.011557, 0.011339, 0.011129, 0.010926, 0.010731,
0.010543, 0.010361, 0.010186, 0.010016, 0.009852, 0.009693, 0.009539,
0.009390, 0.601890, 0.300945, 0.200630, 0.150473, 0.120378, 0.100315,
0.085984, 0.075236, 0.066877, 0.060189, 0.054717, 0.050158, 0.046299,
0.042992, 0.040126, 0.037618, 0.035405, 0.033438, 0.031678, 0.030095,
0.028661, 0.027359, 0.026169, 0.025079, 0.024076, 0.023150, 0.022292,
0.021496, 0.020755, 0.020063, 0.019416, 0.018809, 0.018239, 0.017703,
0.017197, 0.016719, 0.016267, 0.015839, 0.015433, 0.015047, 0.014680,
0.014331, 0.013997, 0.013679, 0.013375, 0.013085, 0.012806, 0.012539,
0.012283, 0.012038, 0.011802, 0.011575, 0.011356, 0.011146, 0.010943,
0.010748, 0.010559, 0.010377, 0.010202, 0.010032, 0.009867, 0.009708,
0.009554, 0.009405, 0.602835, 0.301418, 0.200945, 0.150709, 0.120567,
0.100473, 0.086119, 0.075354, 0.066982, 0.060284, 0.054803, 0.050236,
0.046372, 0.043060, 0.040189, 0.037677, 0.035461, 0.033491, 0.031728,
0.030142, 0.028706, 0.027402, 0.026210, 0.025118, 0.024113, 0.023186,
0.022327, 0.021530, 0.020787, 0.020095, 0.019446, 0.018839, 0.018268,
0.017730, 0.017224, 0.016745, 0.016293, 0.015864, 0.015457, 0.015071,
0.014703, 0.014353, 0.014019, 0.013701, 0.013396, 0.013105, 0.012826,
0.012559, 0.012303, 0.012057, 0.011820, 0.011593, 0.011374, 0.011164,
0.010961, 0.010765, 0.010576, 0.010394, 0.010218, 0.010047, 0.009883,
0.009723, 0.009569, 0.009419, 0.603780, 0.301890, 0.201260, 0.150945,
0.120756, 0.100630, 0.086254, 0.075473, 0.067087, 0.060378, 0.054889,
0.050315, 0.046445, 0.043127, 0.040252, 0.037736, 0.035516, 0.033543,
0.031778, 0.030189, 0.028751, 0.027445, 0.026251, 0.025158, 0.024151,
0.023222, 0.022362, 0.021564, 0.020820, 0.020126, 0.019477, 0.018868,
0.018296, 0.017758, 0.017251, 0.016772, 0.016318, 0.015889, 0.015482,
0.015095, 0.014726, 0.014376, 0.014041, 0.013722, 0.013417, 0.013126,
0.012846, 0.012579, 0.012322, 0.012076, 0.011839, 0.011611, 0.011392,
0.011181, 0.010978, 0.010782, 0.010593, 0.010410, 0.010234, 0.010063,
0.009898, 0.009738, 0.009584, 0.009434, 0.614887, 0.307443, 0.204962,
0.153722, 0.122977, 0.102481, 0.087841, 0.076861, 0.068321, 0.061489,
0.055899, 0.051241, 0.047299, 0.043920, 0.040992, 0.038430, 0.036170,
0.034160, 0.032362, 0.030744, 0.029280, 0.027949, 0.026734, 0.025620,
0.024595, 0.023649, 0.022774, 0.021960, 0.021203, 0.020496, 0.019835,
0.019215, 0.018633, 0.018085, 0.017568, 0.017080, 0.016619, 0.016181,
0.015766, 0.015372, 0.014997, 0.014640, 0.014300, 0.013975, 0.013664,
0.013367, 0.013083, 0.012810, 0.012549, 0.012298, 0.012057, 0.011825,
0.011602, 0.011387, 0.011180, 0.010980, 0.010787, 0.010601, 0.010422,
0.010248, 0.010080, 0.009918, 0.009760, 0.009608, 0.615864, 0.307932,
0.205288, 0.153966, 0.123173, 0.102644, 0.087981, 0.076983, 0.068429,
0.061586, 0.055988, 0.051322, 0.047374, 0.043990, 0.041058, 0.038491,
0.036227, 0.034215, 0.032414, 0.030793, 0.029327, 0.027994, 0.026777,
0.025661, 0.024635, 0.023687, 0.022810, 0.021995, 0.021237, 0.020529,
0.019867, 0.019246, 0.018663, 0.018114, 0.017596, 0.017107, 0.016645,
0.016207, 0.015791, 0.015397, 0.015021, 0.014663, 0.014322, 0.013997,
0.013686, 0.013388, 0.013103, 0.012830, 0.012569, 0.012317, 0.012076,
0.011844, 0.011620, 0.011405, 0.011198, 0.010998, 0.010805, 0.010618,
0.010438, 0.010264, 0.010096, 0.009933, 0.009776, 0.009623, 0.616841,
0.308421, 0.205614, 0.154210, 0.123368, 0.102807, 0.088120, 0.077105,
0.068538, 0.061684, 0.056076, 0.051403, 0.047449, 0.044060, 0.041123,
0.038553, 0.036285, 0.034269, 0.032465, 0.030842, 0.029373, 0.028038,
0.026819, 0.025702, 0.024674, 0.023725, 0.022846, 0.022030, 0.021270,
0.020561, 0.019898, 0.019276, 0.018692, 0.018142, 0.017624, 0.017134,
0.016671, 0.016233, 0.015816, 0.015421, 0.015045, 0.014687, 0.014345,
0.014019, 0.013708, 0.013410, 0.013124, 0.012851, 0.012589, 0.012337,
0.012095, 0.011862, 0.011639, 0.011423, 0.011215, 0.011015, 0.010822,
0.010635, 0.010455, 0.010281, 0.010112, 0.009949, 0.009791, 0.009638,
0.617818, 0.308909, 0.205939, 0.154455, 0.123564, 0.102970, 0.088260,
0.077227, 0.068646, 0.061782, 0.056165, 0.051485, 0.047524, 0.044130,
0.041188, 0.038614, 0.036342, 0.034323, 0.032517, 0.030891, 0.029420,
0.028083, 0.026862, 0.025742, 0.024713, 0.023762, 0.022882, 0.022065,
0.021304, 0.020594, 0.019930, 0.019307, 0.018722, 0.018171, 0.017652,
0.017162, 0.016698, 0.016258, 0.015841, 0.015445, 0.015069, 0.014710,
0.014368, 0.014041, 0.013729, 0.013431, 0.013145, 0.012871, 0.012609,
0.012356, 0.012114, 0.011881, 0.011657, 0.011441, 0.011233, 0.011032,
0.010839, 0.010652, 0.010471, 0.010297, 0.010128, 0.009965, 0.009807,
0.009653};
void TestTiledFlashAttention() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tiles size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
MatStorageT<float> kv(
"kv",
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
// fill in kvs with predictable, synthetic data
for (int i = 0; i < padded_kv_seq_len; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_k = 0.01f * (i + 1) / (j + 1);
const float val_v = 0.02f * (i + 1) / (j + 1);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset] = val_k;
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j] = val_v;
}
}
std::vector<float> q_float(4 * qkv_dim);
std::vector<float> q_float2(4 * qkv_dim);
// fill in qs with predictable, synthetic data
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < qkv_dim; j++) {
float val_1 = 0.01f * (i + 1) / (j + 1);
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
q_float[j * 4 + i] = val_1;
q_float2[j * 4 + i] = val_2;
}
}
const float* q_T[2] = {q_float.data(), q_float2.data()};
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
using DF = hn::ScalableTag<float>;
const DF df;
HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
std::vector<float> max_logits(num_queries_rounded_to_laness);
for (size_t i = 0; i < num_queries; ++i) {
hwy::ZeroBytes(att_out.Row(i),
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
exp_denominator_sums[i] = 0.0f;
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
}
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
ssize_t query_last_pos = kv_seq_end + token_idx;
ssize_t query_start_pos =
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
++q_head_idx) {
start_pos_per_query.push_back(query_start_pos);
last_pos_per_query.push_back(query_last_pos);
}
}
hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kvs, num_queries, hwy::Span<const float*>(q_T, 2),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
// TODO: Replace with Other implementation for generating goldens.
// Current values are taken from a point in time where code was run with gemma
// and output looked good. Not ideal but should be good enough to test the
// plumbing and detect regressions.
PrintMatPtr(att_out);
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f);
}
}
}
void TestTiledFlashAttentionBF16() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tiles size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
MatStorageT<BF16> kv(
"kv",
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
// fill in kvs with predictable, synthetic data
for (int i = 0; i < padded_kv_seq_len; i++) {
for (int j = 0; j < qkv_dim; j+=2) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_k_1 = 0.01f * (i + 1) / (j + 1);
const float val_k_2 = 0.01f * (i + 1) / (j + 2);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2] =
hwy::ConvertScalarTo<BF16>(val_k_1);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1] =
hwy::ConvertScalarTo<BF16>(val_k_2);
}
}
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
for (int i = 0; i < padded_kv_seq_len; i += 2) {
for (int j = 0; j < qkv_dim; j++) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_v_1 = 0.02f * (i + 1) / (j + 1);
const float val_v_2 = 0.02f * (i + 2) / (j + 1);
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2] =
hwy::ConvertScalarTo<BF16>(val_v_1);
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] =
hwy::ConvertScalarTo<BF16>(val_v_2);
}
}
std::vector<BF16> q_float(num_queries_per_timestep * qkv_dim);
std::vector<BF16> q_float2(num_queries_per_timestep * qkv_dim);
// fill in qs with predictable, synthetic data
for (int i = 0; i < num_queries_per_timestep; ++i) {
for (int j = 0; j < qkv_dim; j += 2) {
q_float[j * num_queries_per_timestep + i * 2] =
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 1));
q_float[j * num_queries_per_timestep + i * 2 + 1] =
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 2));
q_float2[j * num_queries_per_timestep + i * 2] =
hwy::ConvertScalarTo<BF16>(
0.01f * (i + num_queries_per_timestep + 1) / (j + 1));
q_float2[j * num_queries_per_timestep + i * 2 + 1] =
hwy::ConvertScalarTo<BF16>(
0.01f * (i + num_queries_per_timestep + 1) / (j + 2));
}
}
const BF16* q_T[2] = {q_float.data(), q_float2.data()};
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
HWY_LANES_CONSTEXPR size_t lanes = 4;
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
std::vector<float> max_logits(num_queries_rounded_to_laness);
for (size_t i = 0; i < num_queries; ++i) {
hwy::ZeroBytes(att_out.Row(i),
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
exp_denominator_sums[i] = 0.0f;
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
}
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
ssize_t query_last_pos = kv_seq_end + token_idx;
ssize_t query_start_pos =
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
++q_head_idx) {
start_pos_per_query.push_back(query_start_pos);
last_pos_per_query.push_back(query_last_pos);
}
}
hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kvs, num_queries, hwy::Span<const BF16*>(q_T, 2),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
// TODO: Replace with Other implementation for generating goldens.
// Current values are taken from a point in time where code was run with gemma
// and output looked good. Not ideal but should be good enough to test the
// plumbing and detect regressions.
PrintMatPtr(att_out);
for (int i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f);
}
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -42,7 +42,8 @@
// After highway.h
#include "gemma/attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "gemma/vit.h" // includes highway.h
#include "gemma/tiled_attention.h" // includes highway.h
#include "gemma/vit.h" // includes highway.h
#ifndef GEMMA_CC_ONCE
#define GEMMA_CC_ONCE
@ -80,6 +81,14 @@ namespace HWY_NAMESPACE {
void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (activations.attention_impl == AttentionImpl::kFlashTransposedQs ||
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
TiledAttention(
activations.attention_impl, num_tokens, layer_idx, layer,
activations.attention, qbatch, env,
AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
return;
}
if (type == LayerAttentionType::kGemma) {
// TODO: remove flag to enable FlashAttention.

View File

@ -148,6 +148,14 @@ struct RuntimeConfig {
// Which attention implementation to use.
AttentionImpl attention_impl = AttentionImpl::kFlash;
// Right now it only work for tiled kv cache, implementations.
// If not set, it will be set based on the attention_impl.
// F32 for tiled
// BF16 for tiled bf16
// If you want to use type other than F32 or BF16, you might need to update
// call upcasted.
std::optional<Type> kv_cache_type = {};
// Functions operating on the generated tokens.
StreamFunc stream_token;
BatchStreamFunc batch_stream_token;

View File

@ -51,10 +51,72 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator) {}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const RuntimeConfig& runtime_config,
const Allocator& allocator)
: allocator_(allocator) {
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
) {
const size_t num_tiles =
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
tiled_seq_len = num_tiles * kTileSize;
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
Type kv_cache_type;
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
) {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
} else {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
}
auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size,
size_t max_seq_len) {
return hwy::DivCeil(
std::min(max_seq_len, window_size + prefill_tbatch_size), kTileSize);
};
size_t total_num_tiles = 0;
for (size_t window_size : config.attention_window_sizes) {
total_num_tiles +=
num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size,
config.max_seq_len) *
config.layer_configs[0].kv_heads;
}
Extents2D extents(total_num_tiles, tile_length);
compact_kv_cache_ptr = MatPtr("kv_tiled", kv_cache_type, extents);
compact_kv_cache.AllocateFor(compact_kv_cache_ptr, allocator,
MatPadding::kPacked);
total_num_tiles = 0;
kv_head_ptrs.reserve(config.attention_window_sizes.size() *
config.layer_configs[0].kv_heads);
for (size_t window_size : config.attention_window_sizes) {
for (size_t kv = 0; kv < config.layer_configs[0].kv_heads; ++kv) {
size_t num_tiles_per_kv_head =
num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size,
config.max_seq_len);
MatPtr kv_ptr("kv_ptr", kv_cache_type,
Extents2D(num_tiles_per_kv_head, tile_length));
kv_ptr.SetPtr(compact_kv_cache_ptr.RowBytes(total_num_tiles),
compact_kv_cache_ptr.Stride());
kv_head_ptrs.emplace_back(std::move(kv_ptr));
total_num_tiles += num_tiles_per_kv_head;
}
}
} else {
kv_cache = MatStorageT<KV_t>(
"kv",
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
allocator, MatPadding::kOdd);
}
}
KVCache KVCache::Copy() {
KVCache copy(kv_cache.Extents(), allocator_);
CopyMat(kv_cache, copy.kv_cache);
CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr);
copy.tiled_seq_len = tiled_seq_len;
return copy;
}

View File

@ -31,31 +31,103 @@
namespace gcpp {
using KV_t = float;
struct KVCache;
// A non-owning view of a KVCache.
struct KVCachePtr {
bool IsEmpty() const { return kv_cache.Rows() == 0; }
size_t SeqLen() const;
bool IsTiled() const;
MatPtrT<KV_t> kv_cache;
KVCache* cache = nullptr;
};
struct KVCache {
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const Allocator& allocator);
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const RuntimeConfig& runtime_config, const Allocator& allocator);
// Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit.
KVCache Copy();
size_t SeqLen() const {
if (IsTiled()) {
return tiled_seq_len.value();
}
return kv_cache.Rows();
}
bool IsTiled() const {
return tiled_seq_len.has_value();
}
// This function returns a vector of pointers and handles wraparound for local
// layers.
// You can use this function to get kv's,
// it will slice internal circular buffer and give you parts of it that are in
// order. Keep in mind that this gives out pointers to tiles, and for local
// layers start_pos might be in a middle of the first tile. At start_pos %
// kTileSize
std::vector<MatPtr> GetPointers(int layer_idx, int kv_head_idx,
int num_kv_heads, int start_pos,
bool is_global_layer) {
if (!IsTiled()) {
HWY_ABORT("This function is only meant to be used with tiled KV caches.");
}
MatPtr& source_ptr = kv_head_ptrs[layer_idx * num_kv_heads + kv_head_idx];
if (is_global_layer) {
return {source_ptr};
}
size_t start_tile_mod_window = (start_pos / kTileSize) % source_ptr.Rows();
size_t start_len = source_ptr.Rows() - start_tile_mod_window;
MatPtr start_ptr("kv_start", source_ptr.GetType(),
Extents2D(start_len, source_ptr.Cols()));
start_ptr.SetPtr(source_ptr.RowBytes(start_tile_mod_window),
source_ptr.Cols());
return {start_ptr, source_ptr};
}
static constexpr size_t kTileSize = 32;
std::optional<uint32_t> tiled_seq_len = std::nullopt;
// Default Format
// If tiled_seq_len is not set, then the kv_cache is assumed to be [seq_len,
// layers * kv_heads * qkv_dim * 2].
//
// Tiled Format
// If tiled_seq_len is set, the kv cache is stored in tiled format.
// Allocations must happen in full tiles.
// The order of dimensions on rows is: [layer, kv_head, tile].
// The total number of rows is:
// num_layers * num_kv_heads * (tiled_seq_len / kTileSize).
// Each tile (containing kTileSize elements from the sequence) can be thought
// of as storing K^T and V, where K is shaped [kTileSize, qkv_dim].
// Type erased kv cache. It's compact because local layers are allocated as
// circular buffers.
MatPtr compact_kv_cache_ptr;
MatOwner compact_kv_cache;
// Pointers to the raw KV storage indexed by layer and head. This helps
// accessing the tiles even though different layers may have a different
// number of tiles in storage. All pointers point into compact_kv_cache.
// To access the tiles of (layer_idx, head_idx), index the array with
// layer_idx * num_kv_heads + kv_head_idx.
// Or use GetPointers function.
// The returned MatPtr will have one tile per row. The number of rows for
// global layers is max_seq_len/kTileSize. For local layers it is slightly
// more than attention_window_size[layer_idx] / kTileSize. For local layers, a
// given token_idx is in row (token_idx / kTileSize) %
// kv_head_ptrs[...].Rows().
std::vector<MatPtr> kv_head_ptrs;
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
KVCachePtr ToPtr() {
return KVCachePtr{
.kv_cache = kv_cache,
.cache = this,
};
}
@ -67,9 +139,17 @@ struct KVCache {
};
inline size_t KVCachePtr::SeqLen() const {
if (IsTiled()) {
return cache->tiled_seq_len.value();
}
return kv_cache.Rows();
}
inline bool KVCachePtr::IsTiled() const {
// MPU code create a KVCachePtr without kv_cache.
return cache != nullptr && cache->tiled_seq_len.has_value();
}
// Convenience function to create views into KVCaches.
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);

660
gemma/tiled_attention.cc Normal file
View File

@ -0,0 +1,660 @@
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>
#include "compression/compress.h"
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "ops/matmul.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// Note: HWY_DISABLED_TARGETS needs to be defined the same everywhere.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "util/basics.h"
#include "util/mat.h"
#include "util/threading_context.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tiled_attention.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/attention.h"
#include "gemma/flash_attention.h" // includes highway.h
#include "gemma/gemma-inl.h"
#include "ops/ops-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
static HWY_INLINE void MergeOnlineSoftmax(
const float* HWY_RESTRICT other_att_out, const float other_softmax_max,
const float other_softmax_d, int qkv_dim,
float* HWY_RESTRICT accumulator_att_out, float& accumulator_softmax_max,
float& accumulator_softmax_d) {
if (other_softmax_d == 0.0f) {
return;
}
if (accumulator_softmax_d == 0.0f) {
memcpy(accumulator_att_out, other_att_out,
qkv_dim * sizeof(*accumulator_att_out));
accumulator_softmax_max = other_softmax_max;
accumulator_softmax_d = other_softmax_d;
return;
}
const float m_new = std::max(accumulator_softmax_max, other_softmax_max);
const float exp_l = std::exp(accumulator_softmax_max - m_new);
const float exp_r = std::exp(other_softmax_max - m_new);
const float d_new = accumulator_softmax_d * exp_l + other_softmax_d * exp_r;
const float d_new_inv = 1.0f / d_new;
const float c1 = accumulator_softmax_d * exp_l * d_new_inv;
const float c2 = other_softmax_d * exp_r * d_new_inv;
MulByConst(c1, accumulator_att_out, qkv_dim);
MulByConstAndAdd(c2, other_att_out, accumulator_att_out, qkv_dim);
accumulator_softmax_max = m_new;
accumulator_softmax_d = d_new;
}
// Forked from ComputeQKV. But it stores the K/V in the tiled format
// KV_T is type stored in the KV cache (typically float or BF16).
template <typename KV_T>
static HWY_INLINE void ComputeQKVTransposedTile(
size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionImpl attention_impl, AttentionActivationsPtrs& activations,
const QBatch& qbatch, const int flags, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKVTiled");
const hwy::Divisor div_qbatch(qbatch.Size());
const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
const LayerConfig& layer_config = layer.layer_config;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
// This computes Q and stores it in activations.q.
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
// This computes Q and stores it in activations.q.
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
/*add=*/nullptr, env, activations.q);
// Compute the combined KV output from pre_att_rms_out.
// The output shape is [num_interleaved, kv_heads * 2 * qkv_dim].
const size_t kv_out_cols = kv_heads * 2 * qkv_dim;
hwy::AlignedFreeUniquePtr<float[]> kv_out_mem =
hwy::AllocateAligned<float>(num_interleaved * kv_out_cols);
float* kv_out_data = kv_out_mem.get();
MatPtrT<float> kv_out_mat("kv_out", Extents2D(num_interleaved, kv_out_cols));
kv_out_mat.SetPtr(kv_out_data, kv_out_cols);
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_out_mat);
// Apply positional encodings and store K/V in tiled format.
hwy::Divisor div_kv_heads(kv_heads);
hn::ScalableTag<float> df;
static hwy::Divisor tile_size_divisor(KVCache::kTileSize);
ParallelFor(
Parallelism::kFlat, kv_heads * qbatch.Size(), env.ctx,
/*cluster_idx=*/0, Callers::kAttComputeQKV,
[&](size_t task, size_t worker) HWY_ATTR {
const size_t kv_head = div_kv_heads.Remainder(task);
const size_t query_idx = div_kv_heads.Divide(task);
CompressPerThread tls;
size_t current_token_idx = 0;
float* k_tile_vec = activations.k_tile_vec.Row(task);
float* v_tile_vec = activations.v_tile_vec.Row(task);
HWY_ALIGN float k_f32[kMaxQKVDim];
const size_t start_pos = qbatch.Pos(query_idx);
const bool is_global_layer =
activations.config.IsGlobalLayer(layer_idx);
std::vector<MatPtr> kv_ptrs =
qbatch.KV(query_idx).cache->GetPointers(
layer_idx, kv_head, kv_heads, start_pos, is_global_layer);
size_t tile_offset = 0;
if (!is_global_layer) {
tile_offset = start_pos / KVCache::kTileSize;
}
while (current_token_idx < num_tokens) {
const size_t pos = start_pos + current_token_idx;
const size_t pos_mod = activations.div_seq_len.Remainder(pos);
const size_t tile_idx = tile_size_divisor.Divide(pos_mod);
const size_t relative_tile_idx = tile_idx - tile_offset;
KV_T* tile_ptr;
int kv_ptr_idx = 0;
size_t absolute_rows = 0;
while (absolute_rows + kv_ptrs[kv_ptr_idx].Rows() <=
relative_tile_idx) {
absolute_rows += kv_ptrs[kv_ptr_idx].Rows();
kv_ptr_idx++;
}
tile_ptr = HWY_RCAST_ALIGNED(
KV_T*,
kv_ptrs[kv_ptr_idx].RowBytes(relative_tile_idx - absolute_rows));
PackedSpan<KV_T> tile_packed_span{tile_ptr,
2 * qkv_dim * KVCache::kTileSize};
DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec,
qkv_dim * KVCache::kTileSize);
DecompressAndZeroPad(df, tile_packed_span,
qkv_dim * KVCache::kTileSize, v_tile_vec,
qkv_dim * KVCache::kTileSize);
size_t token_in_tile_idx = current_token_idx;
while (token_in_tile_idx < num_tokens) {
const size_t current_pos =
qbatch.Pos(query_idx) + token_in_tile_idx;
const size_t current_pos_mod =
activations.div_seq_len.Remainder(current_pos);
if (tile_size_divisor.Divide(current_pos_mod) != tile_idx) {
break; // Moved to next tile
}
const float* kv_row =
kv_out_data +
(token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols;
const float* k_ptr = kv_row + kv_head * 2 * qkv_dim;
const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim;
hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float));
if (layer.key_norm_scale.HasPtr()) {
CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32,
qkv_dim, env.ctx, worker);
});
}
PositionalEncodingQK(
k_f32, layer_idx, activations, env.ctx, worker,
current_pos ,
/*mul=*/1.0f);
const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize;
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
const int in_tile_idx_mod_2 = in_tile_idx % 2;
for (int dim = 0; dim < qkv_dim; dim += 2) {
const int dim_mod_2 = dim % 2;
// Pack k's in pairs in preparation for BF16 dot product.
// See flash_attention.cc
// QDotKTilexUpTo4TransposedKDoubleWidthBF16
k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
in_tile_idx * 2] = k_f32[dim];
k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
in_tile_idx * 2 + 1] = k_f32[dim + 1];
// Pack v's in pairs
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
dim * 2 + in_tile_idx_mod_2] = v_ptr[dim];
v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
(dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1];
}
} else {
for (int i = 0; i < qkv_dim; ++i) {
k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i];
}
Compress(v_ptr, qkv_dim, tls, tile_packed_span,
qkv_dim * (KVCache::kTileSize + in_tile_idx));
}
token_in_tile_idx++;
}
Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, 0);
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, qkv_dim * KVCache::kTileSize);
}
current_token_idx = token_in_tile_idx;
}
});
}
// TODO: optimize with gathers
// This format might change in the future, when kernel will be updated to
// support more than 8 queries.
// Input (num_queries, qkv_dim)
// Output (qkv_dim, num_queries)
void TransposeQ(const MatPtrT<float>& queries,
hwy::Span<float> transposed_queries_span) {
const size_t qkv_dim = queries.Cols();
const size_t num_queries = queries.Rows();
HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim);
for (size_t i = 0; i < qkv_dim; i++) {
for (size_t j = 0; j < num_queries; ++j) {
transposed_queries_span[i * num_queries + j] = queries.Row(j)[i];
}
}
}
// Transposes queries
// Input: vector of pointers to subsequent queries. (allows for arbitrary
// strides)
// qkv_dim: dimension of query
// allocator: aligned allocator to use for temporary storage
//
// Output: Pointer to contiguous memory with shape (qkv_dim,
// queries.size())
void TransposeStridedQueries(
hwy::Span<float*> queries, int qkv_dim,
hwy::Span<float> transposed_queries) {
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
const DF df;
using VF = hn::Vec<DF>;
using DI = hn::ScalableTag<int32_t>;
const DI di;
using VI = hn::Vec<DI>;
const size_t lanes = hn::Lanes(df);
const size_t num_queries = queries.size();
const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, lanes);
std::vector<int32_t, hwy::AlignedAllocator<int32_t>> query_offsets(
num_queries_rounded_up);
for (size_t i = 0; i < num_queries; ++i) {
query_offsets[i] = queries[i] - queries[0];
}
for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
// last offset is the same so gather doesn't read out of bounds
query_offsets[i] = query_offsets[num_queries - 1];
}
for (size_t i = 0; i < qkv_dim; i++) {
size_t j = 0;
if (num_queries >= lanes) {
for (; j <= num_queries-lanes; j += lanes) {
const VI offsets = hn::LoadU(di, query_offsets.data() + j);
VF x = hn::GatherIndex(df, queries[0] + i, offsets);
hn::StoreU(x, df, transposed_queries.data() + i * num_queries + j);
}
}
if (j < num_queries) {
const VI offsets = hn::LoadU(di, query_offsets.data() + j);
VF x = hn::GatherIndex(df, queries[0] + i, offsets);
hn::StoreN(x, df, transposed_queries.data() + i * num_queries + j,
num_queries - j);
}
}
}
std::pair<AlignedFloatVector, std::vector<float*>> TransposeQueriesToGroupsOf4(
hwy::Span<float*> queries_ptrs, int qkv_dim) {
int num_queries = queries_ptrs.size();
int num_groups = hwy::DivCeil(num_queries, 4);
AlignedFloatVector transposed_queries(num_groups * 4 * qkv_dim);
std::vector<float*> transposed_queries_ptrs;
for (int group_idx = 0; group_idx < num_groups; ++group_idx){
int group_size = std::min(4, num_queries - group_idx * 4);
transposed_queries_ptrs.push_back(transposed_queries.data() +
group_idx * qkv_dim * 4);
TransposeStridedQueries(
hwy::Span<float*>(queries_ptrs.data() + group_idx * 4,
group_size),
qkv_dim,
hwy::Span<float>(transposed_queries_ptrs.back(), qkv_dim * group_size));
}
return std::make_pair(std::move(transposed_queries),
std::move(transposed_queries_ptrs));
}
std::pair<AlignedBF16Vector, std::vector<BF16*>>
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
int qkv_dim, int num_queries) {
constexpr int kMaxGroupSize = 4;
int num_groups = queries_ptrs.size();
AlignedBF16Vector transposed_queries(num_groups * kMaxGroupSize * qkv_dim);
std::vector<BF16*> transposed_queries_ptrs;
transposed_queries_ptrs.reserve(num_groups);
for (int group_idx = 0; group_idx < num_groups; ++group_idx) {
int group_size =
std::min(kMaxGroupSize, num_queries - group_idx * kMaxGroupSize);
transposed_queries_ptrs.push_back(transposed_queries.data() +
group_idx * qkv_dim * kMaxGroupSize);
for (int dim_idx = 0; dim_idx < qkv_dim; dim_idx += 2) {
for (int query_idx = 0; query_idx < group_size; ++query_idx) {
transposed_queries_ptrs.back()[dim_idx * group_size + query_idx * 2] =
hwy::ConvertScalarTo<BF16>(
queries_ptrs[group_idx][dim_idx * group_size + query_idx]);
transposed_queries_ptrs
.back()[dim_idx * group_size + query_idx * 2 + 1] =
hwy::ConvertScalarTo<BF16>(
queries_ptrs[group_idx]
[(dim_idx + 1) * group_size + query_idx]);
}
}
}
return std::make_pair(std::move(transposed_queries),
std::move(transposed_queries_ptrs));
}
template <typename T>
static HWY_INLINE void MaybeResizeMatStorage(MatStorageT<T>& mat_storage,
int rows, int cols,
const char* name,
const Allocator& allocator) {
if (mat_storage.Rows() != rows || mat_storage.Cols() != cols) {
mat_storage = MatStorageT<T>(name, Extents2D(rows, cols), allocator,
MatPadding::kOdd);
}
}
// clang-format off
// Schedules TiledFlashAttention for all heads, tokens and batch.
// Returns partial results in the same order as queries in `activations.q`.
// Might not work yet for prefix lm.
// To help understanding how to use this function below is description of how
// parameters are used:
//
// attention_impl - Used to determine attention kernel to use.
// num_query_tokens - number of tokens/timesteps in processed in a single batch
// it will influence how many queries kvs are evaluated against.
// num_kv_tokens - number of tokens/timesteps in kv cache
// layer_idx - layer index
// layer - used to get kv_heads, heads, qkv_dim
// activations - reads: activations.q queries, att_cap, IsGlobalLayer
// qbatch - kv cache, Pos / EndPrefix
// ctx - threading context
// clang-format on
void LocalAttentionForAllHeadsTokensAndBatch(
AttentionImpl attention_impl, const size_t num_query_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations, QBatch& qbatch,
ThreadingContext& ctx) {
const size_t heads_per_kv_head =
layer.layer_config.heads / layer.layer_config.kv_heads;
int core_count = ctx.pools.MaxWorkers();
int task_multiplier = 1;
while (qbatch.Size() * layer.layer_config.kv_heads * task_multiplier <
core_count * 2) {
task_multiplier++;
}
// Finding the smallest context we need to attend to avoid unnecessary
// overhead when sub-splitting doesn't make sense. This check overestimates
// context sizes because it ignores [local] layer sizes and explicit
// qbatch.Prefix settings.
size_t min_pos = qbatch.Pos(0);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
min_pos = std::min(min_pos, qbatch.Pos(qi));
}
if (min_pos / task_multiplier < num_query_tokens) {
// In case where min_pos / task_multiplier < num_tokens
// To make sure we don't over count tokens or read out of bounds code
// requires quite a bit more involved logic.
// Also there is not much point to splitting the work into more tasks, when
// amount of work is small.
task_multiplier = 1;
}
[[maybe_unused]] int num_tasks = qbatch.Size() * layer.layer_config.kv_heads;
[[maybe_unused]] int num_sub_tasks =
qbatch.Size() * layer.layer_config.kv_heads * task_multiplier;
HWY_DASSERT_M(activations.q.Rows() == num_query_tokens * qbatch.Size(),
"qbatch size mismatch");
int qkv_dim = layer.layer_config.qkv_dim;
// sizes of all should be in sync
if (num_sub_tasks > activations.sub_task_att_out->size()) {
activations.sub_task_att_out->resize(num_sub_tasks);
activations.sub_task_exp_denominator_sums->resize(num_sub_tasks);
activations.sub_task_max_logits->resize(num_sub_tasks);
}
std::vector<int> skip_sub_task(num_sub_tasks, 0);
// This loop parallelizes over qbatch, kv_head and substrings of context
// tokens. Each parallel invocation handles all query tokens of the given
// qbatch.
ParallelFor(
Parallelism::kHierarchical, num_sub_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t task_idx, size_t worker) HWY_ATTR {
size_t main_task_idx = task_idx / task_multiplier;
size_t sub_task_idx = task_idx % task_multiplier;
size_t current_qbatch_idx =
main_task_idx / layer.layer_config.kv_heads;
size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
// First and last context token we will attend to.
size_t global_start_context_pos = StartPos(
qbatch.Pos(current_qbatch_idx), activations.config, layer_idx);
// Keep in mind this is overestimation because some timesteps might not
// need all tokens due to causal mask.
// We will use it to determine how to divide work between sub tasks
// and make sure PrefixEnd is taken into account
size_t start_context_pos = global_start_context_pos;
size_t last_context_pos =
qbatch.Pos(current_qbatch_idx) + num_query_tokens - 1;
// In some models, context is limited to some prefix - make sure we take
// that into account.
const size_t prefix_end = qbatch.PrefixEnd(current_qbatch_idx);
if (prefix_end > 0 && prefix_end - 1 > last_context_pos) {
last_context_pos = prefix_end - 1;
}
size_t total_num_context_tokens =
last_context_pos - start_context_pos + 1;
size_t context_tokens_per_sub_task =
hwy::DivCeil(total_num_context_tokens, task_multiplier);
// Restrict tokens to attend to the substring of context tokens that
// this subtask is responsible for.
start_context_pos =
start_context_pos + context_tokens_per_sub_task * sub_task_idx;
if (start_context_pos > last_context_pos) {
skip_sub_task[task_idx] = 1;
return;
}
last_context_pos =
std::min(last_context_pos,
start_context_pos + context_tokens_per_sub_task - 1);
// pre-initialize memory [to avoid racy resizes laters].
int num_queries = num_query_tokens * heads_per_kv_head;
std::vector<float*> queries_ptrs;
queries_ptrs.reserve(num_queries);
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
++q_head_idx) {
queries_ptrs.push_back(
activations.q.Row(token_idx * qbatch.Size() +
current_qbatch_idx) +
(kv_head_idx * heads_per_kv_head + q_head_idx) * qkv_dim);
}
}
hwy::Span<float*> queries_ptrs_span(queries_ptrs.data(),
queries_ptrs.size());
auto [transposed_queries, transposed_queries_ptrs] =
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
MatStorageT<float>& att_out =
activations.sub_task_att_out->at(task_idx);
AlignedFloatVector& exp_denominator_sums =
activations.sub_task_exp_denominator_sums->at(task_idx);
AlignedFloatVector& max_logits =
activations.sub_task_max_logits->at(task_idx);
MaybeResizeMatStorage(att_out, num_queries, qkv_dim, "att_out",
ctx.allocator);
for (int i = 0; i < num_queries; ++i) {
hwy::ZeroBytes(att_out.Row(i),
att_out.Cols() * sizeof(decltype(att_out.Row(i)[0])));
}
int num_queries_rounded_to_8 = hwy::RoundUpTo(num_queries, 8);
exp_denominator_sums.resize(num_queries_rounded_to_8);
max_logits.resize(num_queries_rounded_to_8);
for (int i = 0; i < num_queries_rounded_to_8; ++i) {
exp_denominator_sums[i] = 0.0f;
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
}
// Get pointers to the KVCache tiles, starting at global_start_pos
// Returns multiple matrices for non-contiguous memory, for example as a
// result of the wraparound in local layers.
std::vector<MatPtr> kv_ptrs =
qbatch.KV(current_qbatch_idx)
.cache->GetPointers(
layer_idx, kv_head_idx, layer.layer_config.kv_heads,
global_start_context_pos,
activations.config.IsGlobalLayer(layer_idx));
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
// Position of the first token in the first tile whose pointer was
// returned above. Allows for handling of token positions relative to
// the KV tiles returned above.
size_t rounded_down_global_start_pos =
hwy::RoundDownTo(global_start_context_pos, KVCache::kTileSize);
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
int64_t global_query_pos =
qbatch.Pos(current_qbatch_idx) + token_idx;
// Intersect context to attend to for this specific query token
// to the context tokens of the current subtask.
int64_t query_last_context_pos = std::min(
static_cast<int64_t>(last_context_pos), global_query_pos);
// This max is to not go into negative values, for the same reason we
// use int64_t and not size_t here.
int64_t query_start_context_pos = std::max(
global_query_pos -
static_cast<int64_t>(
activations.config.attention_window_sizes[layer_idx]) +
1,
static_cast<int64_t>(start_context_pos));
// Turn token position into KV-tile relative token positions.
query_last_context_pos -= rounded_down_global_start_pos;
query_start_context_pos -= rounded_down_global_start_pos;
for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
++q_head_idx) {
start_pos_per_query.push_back(query_start_context_pos);
last_pos_per_query.push_back(query_last_context_pos);
}
}
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
// pack transposed queries into BF16
hwy::Span<float*> queries_span(transposed_queries_ptrs.data(),
transposed_queries_ptrs.size());
auto [_, transposed_queries_ptrs_bf16] =
TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim,
num_queries);
hwy::Span<const BF16*> queries_span_bf16(
const_cast<const BF16**>(transposed_queries_ptrs_bf16.data()),
transposed_queries_ptrs_bf16.size());
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kv_ptrs, num_queries, queries_span_bf16,
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
} else {
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kv_ptrs, num_queries,
hwy::Span<const float*>(
const_cast<const float**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size()),
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
}
});
// This loop takes results from separate subtasks (subsequence of kv) and
// merges them into single att_out over whole kv sequence.
ParallelFor(
Parallelism::kFlat, num_tasks, ctx,
/*cluster_idx=*/0, Callers::kFlashAttention,
[&](size_t main_task_idx, size_t worker) HWY_ATTR {
size_t current_qbatch_idx = main_task_idx / layer.layer_config.kv_heads;
size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
for (int head_in_group_idx = 0; head_in_group_idx < heads_per_kv_head;
++head_in_group_idx) {
const size_t batch_index =
current_qbatch_idx * num_query_tokens + token_idx;
const size_t q_head_idx =
kv_head_idx * heads_per_kv_head + head_in_group_idx;
const size_t att_out_row_idx =
token_idx * heads_per_kv_head + head_in_group_idx;
const size_t activations_att_out_start_idx = q_head_idx * qkv_dim;
auto& att_out_0 = activations.sub_task_att_out->at(
main_task_idx * task_multiplier + 0);
auto& exp_denominator_sums_0 =
activations.sub_task_exp_denominator_sums->at(
main_task_idx * task_multiplier + 0);
auto& max_logits_0 = activations.sub_task_max_logits->at(
main_task_idx * task_multiplier + 0);
hwy::CopyBytes(att_out_0.Row(att_out_row_idx),
activations.att_out.Row(batch_index) +
activations_att_out_start_idx,
qkv_dim * sizeof(float));
activations.softmax_d.Row(batch_index)[q_head_idx] =
exp_denominator_sums_0[token_idx * heads_per_kv_head +
head_in_group_idx];
activations.softmax_max.Row(batch_index)[q_head_idx] =
max_logits_0[token_idx * heads_per_kv_head + head_in_group_idx];
for (int sub_task_idx = 1; sub_task_idx < task_multiplier;
++sub_task_idx) {
int task_idx = main_task_idx * task_multiplier + sub_task_idx;
if (skip_sub_task[task_idx] == 1) {
continue;
}
auto& att_out = activations.sub_task_att_out->at(task_idx);
auto& exp_denominator_sums =
activations.sub_task_exp_denominator_sums->at(task_idx);
auto& max_logits = activations.sub_task_max_logits->at(task_idx);
MergeOnlineSoftmax(
att_out.Row(att_out_row_idx),
max_logits[token_idx * heads_per_kv_head + head_in_group_idx],
exp_denominator_sums[token_idx * heads_per_kv_head +
head_in_group_idx],
qkv_dim,
activations.att_out.Row(batch_index) +
activations_att_out_start_idx,
activations.softmax_max.Row(batch_index)[q_head_idx],
activations.softmax_d.Row(batch_index)[q_head_idx]);
}
}
}
});
}
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
AttentionActivationsPtrs& activations, QBatch& qbatch,
MatMulEnv& env, int flags) {
static const auto zone = env.ctx.profiler.AddZone(
"Gen.TiledAttention", hwy::ProfilerFlags::kInclusive);
PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);
const LayerConfig& layer_config = layer.layer_config;
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
(void)layer_config; // only used in HWY_DASSERT
if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) {
ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
} else {
ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
}
RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
layer.query_norm_scale, layer_idx, activations,
env.ctx);
LocalAttentionForAllHeadsTokensAndBatch(attention_impl, num_tokens, layer_idx,
layer, activations, qbatch, env.ctx);
SumHeads(layer, activations, env);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();

42
gemma/tiled_attention.h Normal file
View File

@ -0,0 +1,42 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_
#include <stddef.h>
#include <cstddef>
#include <utility>
#include <vector>
#include "gemma/gemma.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/highway.h"
namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \
hwy::Span<float> transposed_queries); \
void LocalAttentionForAllHeadsTokensAndBatch( \
AttentionImpl attention_impl, const size_t num_tokens, \
const size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE
// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_TILED_ATTENTION)
#undef GEMMA_DECL_TILED_ATTENTION
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_

View File

@ -0,0 +1,749 @@
#include <stddef.h>
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "util/threading_context.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tiled_attention_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/tiled_attention.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
using ::testing::FloatNear;
using ::testing::Pointwise;
struct AttentionTestEnv {
AttentionTestEnv(
int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads,
int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx,
int layers_total, int qbatch_size, AttentionImpl attention_impl,
)
: ctx(threading_args), env(ctx) {
layer_config.heads = num_heads;
layer_config.kv_heads = num_kv_heads;
layer_config.qkv_dim = qkv_dim;
layer_config.model_dim = qkv_dim * num_heads;
model_config.attention_window_sizes = {
static_cast<uint32_t>(attention_window_size)};
model_config.att_cap = att_cap;
model_config.max_seq_len = kv_seq_len;
model_config.num_layers = layers_total;
model_config.model_dim = layer_config.model_dim;
model_config.vocab_size = 1; // not vit
for (int i = 0; i < model_config.num_layers; ++i) {
model_config.layer_configs.push_back(layer_config);
}
tensor_info_registry = std::make_unique<TensorInfoRegistry>(model_config);
layer = std::make_unique<LayerWeightsPtrs>(layer_idx, layer_config,
*tensor_info_registry);
runtime_config.attention_impl = attention_impl;
inference_args.seq_len = kv_seq_len;
all_queries.Reserve(qbatch_size);
kv_caches.reserve(qbatch_size);
for (int q = 0; q < qbatch_size; ++q) {
kv_caches.emplace_back(model_config, inference_args, runtime_config,
ctx.allocator);
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
MatPtrT<BF16> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
for (int j = 0; j < compact_kv_cache.Cols(); ++j) {
BF16 val = hwy::ConvertScalarTo<BF16>(hwy::Unpredictable1() *
0.01f * (i + j + 1));
// split j into if k/v
if (j < qkv_dim * gcpp::KVCache::kTileSize) {
// split j into dim and in tile offset
const int dim = j / gcpp::KVCache::kTileSize;
const int in_tile_offset = j % gcpp::KVCache::kTileSize;
const int dim_mod_2 = dim % 2;
compact_kv_cache.Row(
i)[(dim - dim_mod_2) * gcpp::KVCache::kTileSize +
in_tile_offset * 2 + dim_mod_2] = val;
} else {
const int in_tile_offset = j / qkv_dim;
const int dim = j % qkv_dim;
const int in_tile_offset_mod_2 = in_tile_offset % 2;
compact_kv_cache.Row(
i)[(in_tile_offset - in_tile_offset_mod_2) * qkv_dim +
dim * 2 + in_tile_offset_mod_2] = val;
}
}
}
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
MatPtrT<float> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
FillMatPtrT(compact_kv_cache);
} else {
FillMatPtrT(kv_caches.back().kv_cache);
}
all_queries.Append({
.prompt = PromptTokens({1, 2, 3}),
.mutable_pos = static_cast<size_t>(last_pos),
.initial_pos = 0,
.prefix_end = 0,
.kv_cache = kv_caches.back().ToPtr(),
});
}
activations = std::make_unique<Activations>(runtime_config, model_config,
qbatch_size * num_tokens,
kv_seq_len, ctx, env.row_ptrs);
qbatch =
std::make_unique<QBatch>(/*start_pos=*/0, qbatch_size, all_queries);
}
void SetupWeights() {
int model_dim = layer_config.model_dim;
int qkv_dim = layer_config.qkv_dim;
int num_heads = layer_config.heads;
int num_kv_heads = layer_config.kv_heads;
qkv1_w_storage =
MatStorageT<float>("qkv1", Extents2D(model_dim, qkv_dim * num_heads),
ctx.allocator, MatPadding::kPacked);
qkv2_w_storage = MatStorageT<float>(
"qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator,
MatPadding::kPacked);
wo_w_storage = MatStorageT<float>("wo", Extents2D(model_dim, model_dim),
ctx.allocator, MatPadding::kPacked);
FillMatPtrT(wo_w_storage);
layer->att_weights = wo_w_storage;
FillMatPtrT(qkv1_w_storage);
FillMatPtrT(qkv2_w_storage);
layer->qkv_einsum_w1 = qkv1_w_storage;
layer->qkv_einsum_w2 = qkv2_w_storage;
query_norm_scale = MatStorageT<float>("query_norm", qkv_dim, ctx.allocator);
FillMatPtrT(query_norm_scale);
layer->query_norm_scale = query_norm_scale;
key_norm_scale = MatStorageT<float>("key_norm", qkv_dim, ctx.allocator);
FillMatPtrT(key_norm_scale);
layer->key_norm_scale = key_norm_scale;
}
AttentionTestEnv(const AttentionTestEnv&) = delete;
AttentionTestEnv& operator=(const AttentionTestEnv&) = delete;
AttentionTestEnv(AttentionTestEnv&&) = delete;
AttentionTestEnv& operator=(AttentionTestEnv&&) = delete;
ThreadingArgs threading_args;
ThreadingContext ctx;
MatMulEnv env;
LayerConfig layer_config;
ModelConfig model_config;
std::unique_ptr<TensorInfoRegistry> tensor_info_registry;
std::unique_ptr<LayerWeightsPtrs> layer;
RuntimeConfig runtime_config;
InferenceArgs inference_args;
AllQueries all_queries;
std::vector<KVCache> kv_caches;
std::unique_ptr<Activations> activations;
std::unique_ptr<QBatch> qbatch;
// Weights storage for later tests
MatStorageT<float> qkv1_w_storage;
MatStorageT<float> qkv2_w_storage;
MatStorageT<float> wo_w_storage;
MatStorageT<float> query_norm_scale;
MatStorageT<float> key_norm_scale;
};
void TestTransposeStridedQueries() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
int qkv_dim = 64;
int num_queries = 24;
AlignedPtr<float[]> input_queries =
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
AlignedPtr<float[]> output_queries =
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
input_queries[i * qkv_dim + j] = i * qkv_dim + j;
}
}
std::vector<float*> queries;
for (int i = 0; i < num_queries; ++i) {
queries.push_back(input_queries.get() + i * qkv_dim);
}
hwy::Span<float*> queries_span(queries.data(), queries.size());
TransposeStridedQueries(
queries_span, qkv_dim,
hwy::Span<float>(output_queries.get(), qkv_dim * num_queries));
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_EQ(output_queries[j * num_queries + i],
input_queries[i * qkv_dim + j])
<< "i=" << i << " j=" << j;
}
}
}
void TestLocalAttentionForAllHeadsTokensAndBatch() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 2;
int num_tokens = 2;
int last_pos = 62; // so token 0 will have 63 and token 1 will have 64 tokens
// to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
FillMatPtrT(test_env.activations->attention.q);
LocalAttentionForAllHeadsTokensAndBatch(
attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch, test_env.ctx);
// print states;
std::vector<float> exp_denominator_sums_gold = {63, 63, 64, 64,
63, 63, 64, 64};
std::vector<float> max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10};
std::vector<float> att_out_gold = {
30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485,
30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565,
30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645,
30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725,
30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805,
30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885,
30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965,
30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045,
30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505,
30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585,
30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665,
30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745,
30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825,
30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905,
30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985,
30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065,
30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485,
30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565,
30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645,
30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725,
30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805,
30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885,
30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965,
30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045,
30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505,
30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585,
30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665,
30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745,
30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825,
30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905,
30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985,
30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065,
};
const int group_size = num_heads / num_kv_heads;
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) {
int b = token_idx * qbatch_size + q_batch_idx;
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b),
num_heads),
Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold)
.subspan(b * num_heads, num_heads)));
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b),
num_heads),
Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold)
.subspan(b * num_heads, num_heads)));
for (int kv_h = 0; kv_h < num_kv_heads; ++kv_h) {
for (int g = 0; g < group_size; ++g) {
const int q_h = kv_h * group_size + g;
size_t expected_q_idx = b * num_heads + q_h;
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.att_out.Row(b) +
q_h * qkv_dim,
qkv_dim),
Pointwise(FloatNear(1e-3f),
absl::MakeSpan(att_out_gold)
.subspan(expected_q_idx * qkv_dim, qkv_dim)));
}
}
}
}
}
const std::vector<float> AttentionMultipleTokensAttentionGoldens = {
34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545,
34.985, 35.0156, 35.046, 35.0765, 35.1068, 35.1373, 35.1678, 35.1982,
35.2286, 35.2592, 35.2895, 35.32, 35.3506, 35.381, 35.4115, 35.4421,
35.4725, 35.503, 35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857,
35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294,
35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731,
36.2037, 36.2341, 36.2646, 36.295, 36.3254, 36.356, 36.3863, 36.4168,
36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605,
34.6687, 34.6987, 34.7288, 34.759, 34.7891, 34.8192, 34.8495, 34.8795,
34.9097, 34.9399, 34.97, 35.0002, 35.0302, 35.0604, 35.0906, 35.1206,
35.1507, 35.181, 35.211, 35.2412, 35.2714, 35.3015, 35.3317, 35.3619,
35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603,
35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814, 35.8442,
35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025, 36.0551, 36.0853,
36.1155, 36.1456, 36.1759, 36.2059, 36.236, 36.2662, 36.2963, 36.3264,
36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676,
37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403,
37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447, 37.4764,
37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117,
37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918, 37.9473,
37.977, 38.0065, 38.0358, 38.0655, 38.095, 38.1244, 38.1541, 38.1836,
38.213, 38.2422, 38.2718, 38.3012, 38.3305, 38.36, 38.3895, 38.4187,
38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546,
38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861, 38.8907,
36.9872, 37.0167, 37.046, 37.0752, 37.1047, 37.1341, 37.1633, 37.1928,
37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278,
37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621,
37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966,
37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319,
38.1612, 38.1903, 38.2197, 38.249, 38.2781, 38.3075, 38.3368, 38.366,
38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008,
38.63, 38.6595, 38.6889, 38.7181, 38.7477, 38.777, 38.8063, 38.8358,
39.0984, 39.1479, 39.1976, 39.2475, 39.297, 39.3468, 39.3967, 39.4463,
39.4961, 39.546, 39.5957, 39.6455, 39.695, 39.7447, 39.7946, 39.8441,
39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425,
40.2921, 40.342, 40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403,
40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382,
41.0878, 41.1375, 41.1874, 41.237, 41.2868, 41.3367, 41.3863, 41.4361,
41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834,
41.884, 41.9336, 41.9834, 42.0333, 42.083, 42.1328, 42.1823, 42.232,
38.9699, 39.0188, 39.068, 39.1173, 39.1663, 39.2155, 39.2648, 39.3138,
39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071,
39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009,
40.15, 40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942,
40.5433, 40.5925, 40.6419, 40.691, 40.7402, 40.7892, 40.8383, 40.8876,
40.9366, 40.9857, 41.035, 41.0841, 41.1333, 41.1826, 41.2317, 41.2809,
41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743,
41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677,
43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289,
43.8766, 43.9241, 43.9722, 44.02, 44.0675, 44.1157, 44.1635, 44.2111,
44.2583, 44.3062, 44.3538, 44.4011, 44.449, 44.4966, 44.544, 44.5919,
44.6396, 44.6869, 44.735, 44.7826, 44.8301, 44.8781, 44.9258, 44.9733,
45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26, 45.3081, 45.356,
45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364,
45.7844, 45.832, 45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183,
46.1657, 46.2138, 46.2615, 46.309, 46.3571, 46.4049, 46.4525, 46.5006,
43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453,
43.7928, 43.84, 43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256,
44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046,
44.552, 44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841,
44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265,
45.3123, 45.3593, 45.407, 45.4543, 45.5014, 45.5491, 45.5965, 45.6436,
45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236,
46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404,
34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872,
35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171, 35.2016, 35.2322,
35.2628, 35.2935, 35.324, 35.3547, 35.3854, 35.416, 35.4466, 35.4774,
35.508, 35.5387, 35.5692, 35.5998, 35.6305, 35.661, 35.6916, 35.7224,
35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674,
35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123,
36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574,
36.4881, 36.5186, 36.5493, 36.58, 36.6106, 36.6413, 36.6718, 36.7024,
34.6995, 34.7297, 34.76, 34.7904, 34.8206, 34.8509, 34.8813, 34.9115,
34.9418, 34.9722, 35.0025, 35.0328, 35.063, 35.0933, 35.1237, 35.1539,
35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965,
35.4268, 35.4571, 35.4873, 35.5176, 35.548, 35.5782, 35.6085, 35.6389,
35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813,
35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237,
36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661,
36.3965, 36.4267, 36.457, 36.4874, 36.5177, 36.548, 36.5782, 36.6085,
37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905,
37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277,
37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643,
37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942, 37.9716, 38.0011,
38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209, 38.2387,
38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416, 38.4456, 38.475,
38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121,
38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494,
37.0359, 37.0655, 37.095, 37.1243, 37.154, 37.1835, 37.2129, 37.2425,
37.2721, 37.3014, 37.3311, 37.3607, 37.39, 37.4198, 37.4493, 37.4787,
37.508, 37.5376, 37.567, 37.5963, 37.6259, 37.6553, 37.6846, 37.7142,
37.7437, 37.773, 37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95,
37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865,
38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363, 38.3925, 38.4218,
38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578,
38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835, 38.8644, 38.8941,
39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096,
39.5597, 39.61, 39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91,
39.96, 40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108,
40.3608, 40.411, 40.4608, 40.5108, 40.561, 40.6109, 40.661, 40.7112,
40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116,
41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412, 41.4619, 41.512,
41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125,
41.9627, 42.0127, 42.0628, 42.113, 42.163, 42.2131, 42.263, 42.313,
39.0297, 39.079, 39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759,
39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716,
39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679,
40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514, 40.5637,
40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595,
41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554,
41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512,
41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471,
43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255,
43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101,
44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933,
44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077,
45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414, 45.4621,
45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449,
45.8931, 45.9411, 45.9888, 46.037, 46.085, 46.1327, 46.1811, 46.2291,
46.2768, 46.3252, 46.3733, 46.421, 46.4694, 46.5175, 46.5653, 46.6138,
43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412,
43.889, 43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239,
44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052,
44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987,
45.0352, 45.083, 45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703,
45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512,
45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335,
46.181, 46.2292, 46.277, 46.3245, 46.3727, 46.4206, 46.4682, 46.5164,
};
void TestAttentionMultipleTokens() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
AttentionMultipleTokensAttentionGoldens.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-3,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
int qkv_dim = 64;
int kv_seq_len = 34;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 31; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
int attention_window_size = 32;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size,
num_kv_heads, num_heads, num_tokens, last_pos,
att_cap, layer_idx, layers_total, qbatch_size,
attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
std::vector<float> att_out_golden_test_local = {
39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596,
39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649,
40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707,
40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876,
40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813,
41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867,
41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092,
42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974,
39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115,
39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117,
39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124,
40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126,
40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129,
41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131,
41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134,
41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137,
43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874,
44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776,
44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665,
45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558,
45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466,
45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135,
46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249,
46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153,
43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934,
44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816,
44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684,
44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557,
45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444,
45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307,
46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185,
46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068,
48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903,
49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188,
50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484,
50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577,
51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056,
52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343,
53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763,
53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918,
48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235,
49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428,
49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632,
50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827,
51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022,
52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217,
52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413,
53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609,
57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306,
58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277,
59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722,
59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172,
60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153,
61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088,
61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051,
62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024,
57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162,
58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553,
58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458,
59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374,
60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317,
61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215,
61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141,
62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076,
39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246,
39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324,
40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408,
40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487,
40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567,
41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646,
41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726,
42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806,
39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575,
39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778,
40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381,
40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838,
40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866,
41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894,
41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922,
42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951,
43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287,
44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797,
44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071,
45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628,
45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856,
45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469,
46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392,
46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032,
43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923,
44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829,
44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721,
45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617,
45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529,
45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416,
46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319,
46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226,
48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194,
49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531,
50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878,
51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215,
51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553,
52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892,
53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923,
54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569,
48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25,
49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745,
50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999,
50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244,
51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149,
52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736,
52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982,
53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228,
57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528,
58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299,
59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289,
60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629,
60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332,
61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302,
62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314,
62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335,
57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578,
58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559,
59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512,
59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476,
60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467,
61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412,
61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386,
62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237};
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
att_out_golden_test_local.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-3,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
void TestAttentionMultipleTokensBF16() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.vit_Q);
FillMatPtrT(test_env.activations->attention.vit_K);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
AttentionMultipleTokensAttentionGoldens.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-1,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_BEFORE_TEST(TiledAttentionTest);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
TestLocalAttentionForAllHeadsTokensAndBatch);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);
HWY_AFTER_TEST();
} // namespace gcpp
#endif

View File

@ -1026,6 +1026,450 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
}
HWY_DASSERT(size == i);
}
template <int32_t N, typename DF, class VF = hn::Vec<DF>>
static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT<float>& out,
size_t start_col, VF out0_0, VF out0_1,
VF out1_0, VF out1_1, VF out2_0,
VF out2_1, VF out3_0, VF out3_1,
VF out4_0, VF out4_1, VF out5_0,
VF out5_1, VF out6_0, VF out6_1,
VF out7_0, VF out7_1) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
hn::Store(out0_0, df, out.Row(0) + start_col);
hn::Store(out0_1, df, out.Row(0) + start_col + NF);
if constexpr (N >= 2) {
hn::Store(out1_0, df, out.Row(1) + start_col);
hn::Store(out1_1, df, out.Row(1) + start_col + NF);
}
if constexpr (N >= 3) {
hn::Store(out2_0, df, out.Row(2) + start_col);
hn::Store(out2_1, df, out.Row(2) + start_col + NF);
}
if constexpr (N >= 4) {
hn::Store(out3_0, df, out.Row(3) + start_col);
hn::Store(out3_1, df, out.Row(3) + start_col + NF);
}
if constexpr (N >= 5) {
hn::Store(out4_0, df, out.Row(4) + start_col);
hn::Store(out4_1, df, out.Row(4) + start_col + NF);
}
if constexpr (N >= 6) {
hn::Store(out5_0, df, out.Row(5) + start_col);
hn::Store(out5_1, df, out.Row(5) + start_col + NF);
}
if constexpr (N >= 7) {
hn::Store(out6_0, df, out.Row(6) + start_col);
hn::Store(out6_1, df, out.Row(6) + start_col + NF);
}
if constexpr (N >= 8) {
hn::Store(out7_0, df, out.Row(7) + start_col);
hn::Store(out7_1, df, out.Row(7) + start_col + NF);
}
}
template <int N, typename DF, class VF = hn::Vec<DF>>
static HWY_INLINE void LoadAndMulUpTo8Times2(
DF df, MatPtrT<float>& out, size_t column, const float* HWY_RESTRICT scales,
VF& out0_0, VF& out0_1, VF& out1_0, VF& out1_1, VF& out2_0, VF& out2_1,
VF& out3_0, VF& out3_1, VF& out4_0, VF& out4_1, VF& out5_0, VF& out5_1,
VF& out6_0, VF& out6_1, VF& out7_0, VF& out7_1) {
namespace hn = hwy::HWY_NAMESPACE;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
out0_0 = hn::Load(df, out.Row(0) + column);
out0_0 = hn::Mul(out0_0, hn::Set(df, scales[0]));
out0_1 = hn::Load(df, out.Row(0) + column + NF);
out0_1 = hn::Mul(out0_1, hn::Set(df, scales[0]));
if constexpr (N >= 2) {
out1_0 = hn::Load(df, out.Row(1) + column);
out1_0 = hn::Mul(out1_0, hn::Set(df, scales[1]));
out1_1 = hn::Load(df, out.Row(1) + column + NF);
out1_1 = hn::Mul(out1_1, hn::Set(df, scales[1]));
}
if constexpr (N >= 3) {
out2_0 = hn::Load(df, out.Row(2) + column);
out2_0 = hn::Mul(out2_0, hn::Set(df, scales[2]));
out2_1 = hn::Load(df, out.Row(2) + column + NF);
out2_1 = hn::Mul(out2_1, hn::Set(df, scales[2]));
}
if constexpr (N >= 4) {
out3_0 = hn::Load(df, out.Row(3) + column);
out3_0 = hn::Mul(out3_0, hn::Set(df, scales[3]));
out3_1 = hn::Load(df, out.Row(3) + column + NF);
out3_1 = hn::Mul(out3_1, hn::Set(df, scales[3]));
}
if constexpr (N >= 5) {
out4_0 = hn::Load(df, out.Row(4) + column);
out4_0 = hn::Mul(out4_0, hn::Set(df, scales[4]));
out4_1 = hn::Load(df, out.Row(4) + column + NF);
out4_1 = hn::Mul(out4_1, hn::Set(df, scales[4]));
}
if constexpr (N >= 6) {
out5_0 = hn::Load(df, out.Row(5) + column);
out5_0 = hn::Mul(out5_0, hn::Set(df, scales[5]));
out5_1 = hn::Load(df, out.Row(5) + column + NF);
out5_1 = hn::Mul(out5_1, hn::Set(df, scales[5]));
}
if constexpr (N >= 7) {
out6_0 = hn::Load(df, out.Row(6) + column);
out6_0 = hn::Mul(out6_0, hn::Set(df, scales[6]));
out6_1 = hn::Load(df, out.Row(6) + column + NF);
out6_1 = hn::Mul(out6_1, hn::Set(df, scales[6]));
}
if constexpr (N >= 8) {
out7_0 = hn::Load(df, out.Row(7) + column);
out7_0 = hn::Mul(out7_0, hn::Set(df, scales[7]));
out7_1 = hn::Load(df, out.Row(7) + column + NF);
out7_1 = hn::Mul(out7_1, hn::Set(df, scales[7]));
}
}
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8(
DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1,
const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1,
const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1,
const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1,
const VF& c7_p0, const VF& c7_p1, const VType* HWY_RESTRICT v_tile,
MatPtrT<float>& out) {
static_assert(N <= 8);
namespace hn = hwy::HWY_NAMESPACE;
const size_t qkv_dim = out.Cols();
constexpr size_t kMaxLanes = hn::MaxLanes(df);
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
PackedSpan<const VType> v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF);
size_t i = 0;
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
HWY_ALIGN float consts_buffer[kMaxLanes * N * 2];
hn::Store(c0_p0, df, consts_buffer);
hn::Store(c0_p1, df, consts_buffer + kMaxLanes);
if constexpr (N >= 2) {
hn::Store(c1_p0, df, consts_buffer + 2 * kMaxLanes);
hn::Store(c1_p1, df, consts_buffer + 3 * kMaxLanes);
}
if constexpr (N >= 3) {
hn::Store(c2_p0, df, consts_buffer + 4 * kMaxLanes);
hn::Store(c2_p1, df, consts_buffer + 5 * kMaxLanes);
}
if constexpr (N >= 4) {
hn::Store(c3_p0, df, consts_buffer + 6 * kMaxLanes);
hn::Store(c3_p1, df, consts_buffer + 7 * kMaxLanes);
}
if constexpr (N >= 5) {
hn::Store(c4_p0, df, consts_buffer + 8 * kMaxLanes);
hn::Store(c4_p1, df, consts_buffer + 9 * kMaxLanes);
}
if constexpr (N >= 6) {
hn::Store(c5_p0, df, consts_buffer + 10 * kMaxLanes);
hn::Store(c5_p1, df, consts_buffer + 11 * kMaxLanes);
}
if constexpr (N >= 7) {
hn::Store(c6_p0, df, consts_buffer + 12 * kMaxLanes);
hn::Store(c6_p1, df, consts_buffer + 13 * kMaxLanes);
}
if constexpr (N >= 8) {
hn::Store(c7_p0, df, consts_buffer + 14 * kMaxLanes);
hn::Store(c7_p1, df, consts_buffer + 15 * kMaxLanes);
}
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
while (i + NF * 2 <= qkv_dim) {
VF out0_0, out1_0, out2_0, out3_0, out4_0, out5_0, out6_0, out7_0;
VF out0_1, out1_1, out2_1, out3_1, out4_1, out5_1, out6_1, out7_1;
LoadAndMulUpTo8Times2<N>(df, out, i, scales, out0_0, out0_1, out1_0, out1_1,
out2_0, out2_1, out3_0, out3_1, out4_0, out4_1,
out5_0, out5_1, out6_0, out6_1, out7_0, out7_1);
for (int lane = 0; lane < NF; ++lane) {
VF xI1, xI2;
Decompress2(df, v_span, qkv_dim * lane + i, xI1, xI2);
out0_0 = hn::MulAdd(xI1, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]),
out0_0);
out0_1 = hn::MulAdd(xI2, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]),
out0_1);
if constexpr (N >= 2) {
out1_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_0);
out1_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_1);
}
if constexpr (N >= 3) {
out2_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_0);
out2_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_1);
}
if constexpr (N >= 4) {
out3_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_0);
out3_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_1);
}
if constexpr (N >= 5) {
out4_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_0);
out4_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_1);
}
if constexpr (N >= 6) {
out5_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_0);
out5_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_1);
}
if constexpr (N >= 7) {
out6_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_0);
out6_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_1);
}
if constexpr (N >= 8) {
out7_0 = hn::MulAdd(
xI1, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_0);
out7_1 = hn::MulAdd(
xI2, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_1);
}
VF xI3, xI4;
Decompress2(df, v_span, qkv_dim * (NF + lane) + i, xI3, xI4);
out0_0 = hn::MulAdd(xI3, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]),
out0_0);
out0_1 = hn::MulAdd(xI4, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]),
out0_1);
if constexpr (N >= 2) {
out1_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_0);
out1_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_1);
}
if constexpr (N >= 3) {
out2_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_0);
out2_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_1);
}
if constexpr (N >= 4) {
out3_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_0);
out3_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_1);
}
if constexpr (N >= 5) {
out4_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_0);
out4_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_1);
}
if constexpr (N >= 6) {
out5_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_0);
out5_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_1);
}
if constexpr (N >= 7) {
out6_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_0);
out6_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_1);
}
if constexpr (N >= 8) {
out7_0 = hn::MulAdd(
xI3, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_0);
out7_1 = hn::MulAdd(
xI4, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_1);
}
}
StoreUpTo8Times2<N>(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0,
out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1,
out6_0, out6_1, out7_0, out7_1);
i += 2 * NF;
}
HWY_DASSERT(qkv_dim == i);
}
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0,
VF c1_p1, VF c2_p0, VF c2_p1, VF c3_p0, VF c3_p1, VF c4_p0, VF c4_p1,
VF c5_p0, VF c5_p1, VF c6_p0, VF c6_p1, VF c7_p0, VF c7_p1,
VType* HWY_RESTRICT v_tile, MatPtrT<float>& out) {
static_assert(N <= 8);
namespace hn = hwy::HWY_NAMESPACE;
const size_t qkv_dim = out.Cols();
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
constexpr size_t kMaxLanes = hn::MaxLanes(df);
using DBF = hn::ScalableTag<BF16>;
const DBF dbf;
using VBF = hn::Vec<DBF>;
PackedSpan<const VType> v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF);
HWY_ALIGN BF16 cs[N * kMaxLanes * 2];
PackedSpan<BF16> cs_span = MakeSpan(cs, N * kMaxLanes * 2);
float* cs_as_float = HWY_RCAST_ALIGNED(float*, cs);
Compress2(df, c0_p0, c0_p1, cs_span, 0);
if constexpr (N >= 2) {
Compress2(df, c1_p0, c1_p1, cs_span, kMaxLanes * 2);
}
if constexpr (N >= 3) {
Compress2(df, c2_p0, c2_p1, cs_span, 2 * kMaxLanes * 2);
}
if constexpr (N >= 4) {
Compress2(df, c3_p0, c3_p1, cs_span, 3 * kMaxLanes * 2);
}
if constexpr (N >= 5) {
Compress2(df, c4_p0, c4_p1, cs_span, 4 * kMaxLanes * 2);
}
if constexpr (N >= 6) {
Compress2(df, c5_p0, c5_p1, cs_span, 5 * kMaxLanes * 2);
}
if constexpr (N >= 7) {
Compress2(df, c6_p0, c6_p1, cs_span, 6 * kMaxLanes * 2);
}
if constexpr (N >= 8) {
Compress2(df, c7_p0, c7_p1, cs_span, 7 * kMaxLanes * 2);
}
VF zero = hn::Zero(df);
size_t i = 0;
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
while (i + NF * 2 <= qkv_dim) {
VF out0_0, out1_0, out2_0, out3_0;
VF out0_1, out1_1, out2_1, out3_1;
VF out4_0, out5_0, out6_0, out7_0;
VF out4_1, out5_1, out6_1, out7_1;
VF helper_out0_0 = hn::Zero(df), helper_out0_1 = hn::Zero(df),
helper_out1_0 = hn::Zero(df), helper_out1_1 = hn::Zero(df),
helper_out2_0 = hn::Zero(df), helper_out2_1 = hn::Zero(df),
helper_out3_0 = hn::Zero(df), helper_out3_1 = hn::Zero(df),
helper_out4_0 = hn::Zero(df), helper_out4_1 = hn::Zero(df),
helper_out5_0 = hn::Zero(df), helper_out5_1 = hn::Zero(df),
helper_out6_0 = hn::Zero(df), helper_out6_1 = hn::Zero(df),
helper_out7_0 = hn::Zero(df), helper_out7_1 = hn::Zero(df);
LoadAndMulUpTo8Times2<N>(df, out, i, scales, out0_0, out0_1, out1_0, out1_1,
out2_0, out2_1, out3_0, out3_1, out4_0, out4_1,
out5_0, out5_1, out6_0, out6_1, out7_0, out7_1);
for (int lane = 0; lane < NF; ++lane) {
VBF xI, xI2;
Decompress2(dbf, v_span, 2 * qkv_dim * lane + i * 2, xI, xI2);
// Set pair of c scales for 2 value vectors
out0_0 = hn::ReorderWidenMulAccumulate(
df, xI, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_0,
helper_out0_0);
out0_1 = hn::ReorderWidenMulAccumulate(
df, xI2, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_1,
helper_out0_1);
if constexpr (N >= 2) {
out1_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])),
out1_0, helper_out1_0);
out1_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])),
out1_1, helper_out1_1);
}
if constexpr (N >= 3) {
out2_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])),
out2_0, helper_out2_0);
out2_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])),
out2_1, helper_out2_1);
}
if constexpr (N >= 4) {
out3_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])),
out3_0, helper_out3_0);
out3_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])),
out3_1, helper_out3_1);
}
if constexpr (N >= 5) {
out4_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])),
out4_0, helper_out4_0);
out4_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])),
out4_1, helper_out4_1);
}
if constexpr (N >= 6) {
out5_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])),
out5_0, helper_out5_0);
out5_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])),
out5_1, helper_out5_1);
}
if constexpr (N >= 7) {
out6_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])),
out6_0, helper_out6_0);
out6_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])),
out6_1, helper_out6_1);
}
if constexpr (N >= 8) {
out7_0 = hn::ReorderWidenMulAccumulate(
df, xI,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])),
out7_0, helper_out7_0);
out7_1 = hn::ReorderWidenMulAccumulate(
df, xI2,
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])),
out7_1, helper_out7_1);
}
}
#if HWY_NATIVE_DOT_BF16 == 0
out0_0 = hn::Add(out0_0, helper_out0_0);
out0_1 = hn::Add(out0_1, helper_out0_1);
if constexpr (N >= 2) {
out1_0 = hn::Add(out1_0, helper_out1_0);
out1_1 = hn::Add(out1_1, helper_out1_1);
}
if constexpr (N >= 3) {
out2_0 = hn::Add(out2_0, helper_out2_0);
out2_1 = hn::Add(out2_1, helper_out2_1);
}
if constexpr (N >= 4) {
out3_0 = hn::Add(out3_0, helper_out3_0);
out3_1 = hn::Add(out3_1, helper_out3_1);
}
if constexpr (N >= 5) {
out4_0 = hn::Add(out4_0, helper_out4_0);
out4_1 = hn::Add(out4_1, helper_out4_1);
}
if constexpr (N >= 6) {
out5_0 = hn::Add(out5_0, helper_out5_0);
out5_1 = hn::Add(out5_1, helper_out5_1);
}
if constexpr (N >= 7) {
out6_0 = hn::Add(out6_0, helper_out6_0);
out6_1 = hn::Add(out6_1, helper_out6_1);
}
if constexpr (N >= 8) {
out7_0 = hn::Add(out7_0, helper_out7_0);
out7_1 = hn::Add(out7_1, helper_out7_1);
}
#endif
StoreUpTo8Times2<N>(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0,
out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1,
out6_0, out6_1, out7_0, out7_1);
i += 2 * NF;
}
HWY_DASSERT(qkv_dim == i);
}
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
// corresponding values in c0 and adds them to the NF rows of out.