mirror of https://github.com/google/gemma.cpp.git
Implementation of tiled attention with bf16 and circular buffers which reduces memory requirements by 4x on longer context on gemma models.
It also supports better parallelism for small batch sizes / small models. It also is able to utilize VDPBF16PS for nice 2x improvement on avx512 PiperOrigin-RevId: 874517319
This commit is contained in:
parent
463a3682be
commit
df162ead7c
|
|
@ -652,10 +652,12 @@ cc_library(
|
|||
name = "gemma_lib",
|
||||
srcs = [
|
||||
"gemma/gemma.cc",
|
||||
"gemma/tiled_attention.cc",
|
||||
"gemma/vit.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/gemma.h",
|
||||
"gemma/tiled_attention.h",
|
||||
"gemma/vit.h",
|
||||
],
|
||||
exec_properties = {
|
||||
|
|
|
|||
|
|
@ -93,6 +93,8 @@ set(SOURCES
|
|||
gemma/model_store.h
|
||||
gemma/tensor_info.cc
|
||||
gemma/tensor_info.h
|
||||
gemma/tiled_attention.cc
|
||||
gemma/tiled_attention.h
|
||||
gemma/tokenizer.cc
|
||||
gemma/tokenizer.h
|
||||
gemma/vit.cc
|
||||
|
|
@ -171,20 +173,20 @@ install(TARGETS libgemma DESTINATION lib)
|
|||
if(BUILD_GEMMA_DLL)
|
||||
add_library(gemma_shared SHARED ${SOURCES})
|
||||
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
|
||||
set_target_properties(gemma_shared PROPERTIES
|
||||
set_target_properties(gemma_shared PROPERTIES
|
||||
PREFIX ""
|
||||
OUTPUT_NAME "gemma"
|
||||
)
|
||||
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(gemma_shared PUBLIC ./)
|
||||
target_link_libraries(gemma_shared PRIVATE
|
||||
target_link_libraries(gemma_shared PRIVATE
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
|
||||
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
|
||||
)
|
||||
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
|
||||
target_compile_definitions(gemma_shared
|
||||
PRIVATE
|
||||
target_compile_definitions(gemma_shared
|
||||
PRIVATE
|
||||
GEMMA_EXPORTS
|
||||
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
|
||||
)
|
||||
|
|
|
|||
|
|
@ -153,6 +153,14 @@ struct AttentionActivations {
|
|||
// Accumulation of attention outputs over heads
|
||||
MatStorageT<BF16> att_sums;
|
||||
|
||||
MatStorageT<float> k_tile_vec;
|
||||
MatStorageT<float> v_tile_vec;
|
||||
std::vector<MatStorageT<float>> sub_task_att_out;
|
||||
std::vector<AlignedFloatVector>
|
||||
sub_task_exp_denominator_sums;
|
||||
std::vector<AlignedFloatVector>
|
||||
sub_task_max_logits;
|
||||
|
||||
// Rope
|
||||
MatStorageT<float> inv_timescale;
|
||||
MatStorageT<float> inv_timescale_global;
|
||||
|
|
@ -244,6 +252,16 @@ struct AttentionActivationsPtrs {
|
|||
// Accumulation of attention outputs over heads, size batch_size x
|
||||
// model_dim.
|
||||
MatPtrT<BF16> att_sums;
|
||||
// Stores intermediate results of computing QKV,
|
||||
// [qbatch * kv_heads , k_tile_size * qkv_dim]
|
||||
MatPtrT<float> k_tile_vec;
|
||||
MatPtrT<float> v_tile_vec;
|
||||
// Used by TiledFlashAttention to store intermediate results.
|
||||
std::vector<MatStorageT<float>>* sub_task_att_out;
|
||||
std::vector<AlignedFloatVector>*
|
||||
sub_task_exp_denominator_sums;
|
||||
std::vector<AlignedFloatVector>*
|
||||
sub_task_max_logits;
|
||||
// Inverse timescales for RoPE computation.
|
||||
MatPtrT<float> inv_timescale;
|
||||
// Inverse timescales for global RoPE computation.
|
||||
|
|
|
|||
|
|
@ -83,6 +83,8 @@ static inline bool EnumValid(LayerAttentionType type) {
|
|||
enum class AttentionImpl {
|
||||
kOld,
|
||||
kFlash,
|
||||
kFlashTransposedQs,
|
||||
kFlashTransposedQsBF16,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
|
|
@ -108,6 +110,8 @@ static inline int AttentionImplToFlags(AttentionImpl impl,
|
|||
case AttentionImpl::kOld:
|
||||
return kAttentionUseOld;
|
||||
case AttentionImpl::kFlash:
|
||||
case AttentionImpl::kFlashTransposedQs:
|
||||
case AttentionImpl::kFlashTransposedQsBF16:
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -921,6 +921,620 @@ Tile4FlashState TileFlashAttention4(
|
|||
return state;
|
||||
}
|
||||
|
||||
template <int kNumQueries, typename Q_T, class DQ_T, class VQ_T = hn::Vec<DQ_T>,
|
||||
typename T>
|
||||
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
|
||||
DQ_T df, const Q_T* HWY_RESTRICT q, const Q_T* HWY_RESTRICT q2,
|
||||
const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VQ_T& sum0_p0,
|
||||
VQ_T& sum0_p1, VQ_T& sum1_p0, VQ_T& sum1_p1, VQ_T& sum2_p0, VQ_T& sum2_p1,
|
||||
VQ_T& sum3_p0, VQ_T& sum3_p1, VQ_T& sum4_p0, VQ_T& sum4_p1, VQ_T& sum5_p0,
|
||||
VQ_T& sum5_p1, VQ_T& sum6_p0, VQ_T& sum6_p1, VQ_T& sum7_p0, VQ_T& sum7_p1) {
|
||||
const PackedSpan<const T> k_transposed_span =
|
||||
MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim);
|
||||
HWY_DASSERT(kNumQueries <= 8);
|
||||
HWY_DASSERT(gcpp::KVCache::kTileSize >=
|
||||
hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time.
|
||||
sum0_p0 = hn::Zero(df);
|
||||
sum0_p1 = hn::Zero(df);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
sum1_p0 = hn::Zero(df);
|
||||
sum1_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
sum2_p0 = hn::Zero(df);
|
||||
sum2_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
sum3_p0 = hn::Zero(df);
|
||||
sum3_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
sum4_p0 = hn::Zero(df);
|
||||
sum4_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
sum5_p0 = hn::Zero(df);
|
||||
sum5_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
sum6_p0 = hn::Zero(df);
|
||||
sum6_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
sum7_p0 = hn::Zero(df);
|
||||
sum7_p1 = hn::Zero(df);
|
||||
}
|
||||
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
kNumQueries - kFirstHalfAmountOfQueries;
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < qkv_dim; ++i) {
|
||||
VQ_T k_vec1, k_vec2;
|
||||
if constexpr (HWY_TARGET == HWY_AVX2) {
|
||||
hwy::Prefetch(k_transposed_span.ptr + (i + 3) * gcpp::KVCache::kTileSize);
|
||||
hwy::Prefetch(k_transposed_span.ptr + (i + 4) * gcpp::KVCache::kTileSize);
|
||||
}
|
||||
Decompress2(df, k_transposed_span, i * gcpp::KVCache::kTileSize, k_vec1,
|
||||
k_vec2);
|
||||
sum0_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p0);
|
||||
sum0_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 0]), sum0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
sum1_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p0);
|
||||
sum1_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 1]), sum1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
sum2_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p0);
|
||||
sum2_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 2]), sum2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
sum3_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p0);
|
||||
sum3_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q[i * kFirstHalfAmountOfQueries + 3]), sum3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
sum4_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p0);
|
||||
sum4_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 0]), sum4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
sum5_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p0);
|
||||
sum5_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 1]), sum5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
sum6_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p0);
|
||||
sum6_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 2]), sum6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
sum7_p0 = hn::MulAdd(
|
||||
k_vec1, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p0);
|
||||
sum7_p1 = hn::MulAdd(
|
||||
k_vec2, hn::Set(df, q2[i * kSecondHalfAmountOfQueries + 3]), sum7_p1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename T>
|
||||
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
|
||||
DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2,
|
||||
const T* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
|
||||
VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
|
||||
VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
|
||||
VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) {
|
||||
using DBF = hn::ScalableTag<BF16>;
|
||||
const DBF dbf;
|
||||
using VBF = hn::Vec<DBF>;
|
||||
const PackedSpan<const T> k_transposed_span =
|
||||
MakeConstSpan(k_transposed_tile, gcpp::KVCache::kTileSize * qkv_dim);
|
||||
[[maybe_unused]] HWY_LANES_CONSTEXPR size_t lanes_bf16 = hn::Lanes(dbf);
|
||||
HWY_DASSERT(hn::Lanes(dbf) <= gcpp::KVCache::kTileSize);
|
||||
HWY_DASSERT(kNumQueries <= 8);
|
||||
HWY_DASSERT(gcpp::KVCache::kTileSize >=
|
||||
hn::Lanes(df) * 2); // So we can decompress 2 lanes at a time.
|
||||
sum0_p0 = hn::Zero(df);
|
||||
sum0_p1 = hn::Zero(df);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
sum1_p0 = hn::Zero(df);
|
||||
sum1_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
sum2_p0 = hn::Zero(df);
|
||||
sum2_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
sum3_p0 = hn::Zero(df);
|
||||
sum3_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
sum4_p0 = hn::Zero(df);
|
||||
sum4_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
sum5_p0 = hn::Zero(df);
|
||||
sum5_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
sum6_p0 = hn::Zero(df);
|
||||
sum6_p1 = hn::Zero(df);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
sum7_p0 = hn::Zero(df);
|
||||
sum7_p1 = hn::Zero(df);
|
||||
}
|
||||
VF helper_sum0_p0 = hn::Zero(df), helper_sum0_p1 = hn::Zero(df);
|
||||
VF helper_sum1_p0 = hn::Zero(df), helper_sum1_p1 = hn::Zero(df);
|
||||
VF helper_sum2_p0 = hn::Zero(df), helper_sum2_p1 = hn::Zero(df);
|
||||
VF helper_sum3_p0 = hn::Zero(df), helper_sum3_p1 = hn::Zero(df);
|
||||
VF helper_sum4_p0 = hn::Zero(df), helper_sum4_p1 = hn::Zero(df);
|
||||
VF helper_sum5_p0 = hn::Zero(df), helper_sum5_p1 = hn::Zero(df);
|
||||
VF helper_sum6_p0 = hn::Zero(df), helper_sum6_p1 = hn::Zero(df);
|
||||
VF helper_sum7_p0 = hn::Zero(df), helper_sum7_p1 = hn::Zero(df);
|
||||
const float* q_float_ptr = HWY_RCAST_ALIGNED(const float*, q);
|
||||
const float* q2_float_ptr = HWY_RCAST_ALIGNED(const float*, q2);
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
kNumQueries - kFirstHalfAmountOfQueries;
|
||||
|
||||
for (size_t i = 0; i < qkv_dim / 2; i++) {
|
||||
VBF k_vec1, k_vec2;
|
||||
Decompress2(dbf, k_transposed_span, i * 2 * gcpp::KVCache::kTileSize,
|
||||
k_vec1, k_vec2);
|
||||
|
||||
VF q_0_as_float = hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries]);
|
||||
VBF q_0 = hn::BitCast(dbf, q_0_as_float);
|
||||
sum0_p0 =
|
||||
hn::ReorderWidenMulAccumulate(df, k_vec1, q_0, sum0_p0, helper_sum0_p0);
|
||||
sum0_p1 =
|
||||
hn::ReorderWidenMulAccumulate(df, k_vec2, q_0, sum0_p1, helper_sum0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
VF q_1_as_float =
|
||||
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 1]);
|
||||
VBF q_1 = hn::BitCast(dbf, q_1_as_float);
|
||||
sum1_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_1, sum1_p0,
|
||||
helper_sum1_p0);
|
||||
sum1_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_1, sum1_p1,
|
||||
helper_sum1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
VF q_2_as_float =
|
||||
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 2]);
|
||||
VBF q_2 = hn::BitCast(dbf, q_2_as_float);
|
||||
sum2_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_2, sum2_p0,
|
||||
helper_sum2_p0);
|
||||
sum2_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_2, sum2_p1,
|
||||
helper_sum2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
VF q_3_as_float =
|
||||
hn::Set(df, q_float_ptr[i * kFirstHalfAmountOfQueries + 3]);
|
||||
VBF q_3 = hn::BitCast(dbf, q_3_as_float);
|
||||
sum3_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_3, sum3_p0,
|
||||
helper_sum3_p0);
|
||||
sum3_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_3, sum3_p1,
|
||||
helper_sum3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
VF q_4_as_float =
|
||||
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 0]);
|
||||
VBF q_4 = hn::BitCast(dbf, q_4_as_float);
|
||||
sum4_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_4, sum4_p0,
|
||||
helper_sum4_p0);
|
||||
sum4_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_4, sum4_p1,
|
||||
helper_sum4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
VF q_5_as_float =
|
||||
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 1]);
|
||||
VBF q_5 = hn::BitCast(dbf, q_5_as_float);
|
||||
sum5_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_5, sum5_p0,
|
||||
helper_sum5_p0);
|
||||
sum5_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_5, sum5_p1,
|
||||
helper_sum5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
VF q_6_as_float =
|
||||
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 2]);
|
||||
VBF q_6 = hn::BitCast(dbf, q_6_as_float);
|
||||
sum6_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_6, sum6_p0,
|
||||
helper_sum6_p0);
|
||||
sum6_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_6, sum6_p1,
|
||||
helper_sum6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
VF q_7_as_float =
|
||||
hn::Set(df, q2_float_ptr[i * kSecondHalfAmountOfQueries + 3]);
|
||||
VBF q_7 = hn::BitCast(dbf, q_7_as_float);
|
||||
sum7_p0 = hn::ReorderWidenMulAccumulate(df, k_vec1, q_7, sum7_p0,
|
||||
helper_sum7_p0);
|
||||
sum7_p1 = hn::ReorderWidenMulAccumulate(df, k_vec2, q_7, sum7_p1,
|
||||
helper_sum7_p1);
|
||||
}
|
||||
}
|
||||
#if HWY_NATIVE_DOT_BF16 == 0
|
||||
sum0_p0 = hn::Add(sum0_p0, helper_sum0_p0);
|
||||
sum0_p1 = hn::Add(sum0_p1, helper_sum0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
sum1_p0 = hn::Add(sum1_p0, helper_sum1_p0);
|
||||
sum1_p1 = hn::Add(sum1_p1, helper_sum1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
sum2_p0 = hn::Add(sum2_p0, helper_sum2_p0);
|
||||
sum2_p1 = hn::Add(sum2_p1, helper_sum2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
sum3_p0 = hn::Add(sum3_p0, helper_sum3_p0);
|
||||
sum3_p1 = hn::Add(sum3_p1, helper_sum3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
sum4_p0 = hn::Add(sum4_p0, helper_sum4_p0);
|
||||
sum4_p1 = hn::Add(sum4_p1, helper_sum4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
sum5_p0 = hn::Add(sum5_p0, helper_sum5_p0);
|
||||
sum5_p1 = hn::Add(sum5_p1, helper_sum5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
sum6_p0 = hn::Add(sum6_p0, helper_sum6_p0);
|
||||
sum6_p1 = hn::Add(sum6_p1, helper_sum6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
sum7_p0 = hn::Add(sum7_p0, helper_sum7_p0);
|
||||
sum7_p1 = hn::Add(sum7_p1, helper_sum7_p1);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int kVTileSize, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void ApplySoftCap(DF df, float att_cap, float one_over_cap,
|
||||
VF& x0, VF& x1, VF& x2, VF& x3, VF& x4,
|
||||
VF& x5, VF& x6, VF& x7) {
|
||||
if (att_cap > 0.0f) {
|
||||
VF cap = hn::Set(df, att_cap);
|
||||
VF one_over_cap_vec = hn::Set(df, one_over_cap);
|
||||
x0 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x0, one_over_cap_vec)));
|
||||
if constexpr (kVTileSize >= 2) {
|
||||
x1 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x1, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 3) {
|
||||
x2 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x2, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 4) {
|
||||
x3 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x3, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 5) {
|
||||
x4 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x4, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 6) {
|
||||
x5 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x5, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 7) {
|
||||
x6 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x6, one_over_cap_vec)));
|
||||
}
|
||||
if constexpr (kVTileSize >= 8) {
|
||||
x7 = hn::Mul(cap, hn::CallTanh(df, hn::Mul(x7, one_over_cap_vec)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename DU,
|
||||
class VU = hn::Vec<DU>>
|
||||
static HWY_NOINLINE void ApplyMasking(
|
||||
DF df, DU du, size_t position,
|
||||
const size_t* HWY_RESTRICT first_pos_per_query,
|
||||
const size_t* HWY_RESTRICT last_pos_per_query, VF& x0_p0, VF& x0_p1,
|
||||
VF& x1_p0, VF& x1_p1, VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0,
|
||||
VF& x4_p1, VF& x5_p0, VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0,
|
||||
VF& x7_p1) {
|
||||
VU lane_indices = hn::Iota(du, 0);
|
||||
HWY_LANES_CONSTEXPR size_t kTileSize = hn::Lanes(df);
|
||||
auto per_lane_pos_p0 = hn::Add(hn::Set(du, position), lane_indices);
|
||||
auto per_lane_pos_p1 =
|
||||
hn::Add(hn::Set(du, position + kTileSize), lane_indices);
|
||||
|
||||
VF neg_inf = hn::Set(df, kNegInf);
|
||||
|
||||
auto apply_mask_for_query = [&](int query_idx, VF& x_p0, VF& x_p1) HWY_ATTR {
|
||||
const size_t first_pos = first_pos_per_query[query_idx];
|
||||
const size_t last_pos = last_pos_per_query[query_idx];
|
||||
|
||||
auto valid_tokens_mask_p0 = hn::Ge(per_lane_pos_p0, hn::Set(du, first_pos));
|
||||
valid_tokens_mask_p0 = hn::And(
|
||||
valid_tokens_mask_p0, hn::Le(per_lane_pos_p0, hn::Set(du, last_pos)));
|
||||
x_p0 =
|
||||
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p0), x_p0, neg_inf);
|
||||
|
||||
auto valid_tokens_mask_p1 = hn::Ge(per_lane_pos_p1, hn::Set(du, first_pos));
|
||||
valid_tokens_mask_p1 = hn::And(
|
||||
valid_tokens_mask_p1, hn::Le(per_lane_pos_p1, hn::Set(du, last_pos)));
|
||||
x_p1 =
|
||||
hn::IfThenElse(hn::RebindMask(df, valid_tokens_mask_p1), x_p1, neg_inf);
|
||||
};
|
||||
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
apply_mask_for_query(0, x0_p0, x0_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
apply_mask_for_query(1, x1_p0, x1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
apply_mask_for_query(2, x2_p0, x2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
apply_mask_for_query(3, x3_p0, x3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
apply_mask_for_query(4, x4_p0, x4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
apply_mask_for_query(5, x5_p0, x5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
apply_mask_for_query(6, x6_p0, x6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
apply_mask_for_query(7, x7_p0, x7_p1);
|
||||
}
|
||||
}
|
||||
|
||||
// Performs tiled flash attention for arbitrary number of queries
|
||||
// It depends on kv being tiled.
|
||||
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
|
||||
// It moves NF*2 timesteps forward in kv at a time.
|
||||
// Args:
|
||||
// kvs - hwy::Span of MatPtrT<KV_T> of shape (kvs, (tile_count, qkv_dim *
|
||||
// kTileSize * 2)) This span allows to pass kv cache that is not contiguous,
|
||||
// all except for the last one should have theirs row count be true,
|
||||
// as it will be used to figure out when to switch to the next one.
|
||||
// q_T_in_groups_up_to_4 - Span of float* All except last float*
|
||||
// should have (qkv_dim, 4) Last one can have any size up to 4.
|
||||
// start_pos_per_query - start position in kv to start attention from ()
|
||||
// last_pos_per_query - last position in kv to attend to (exclusive)
|
||||
// queries_per_timestep - how many queries begin/end on the same timestep
|
||||
// attention_shape - see struct definition for more details.
|
||||
// att_cap - soft cap on attention logits
|
||||
// att_out - MatPtrT<float> of shape (q_count, qkv_dim)
|
||||
// exp_denominator_sums and max_logits: float* of shape:
|
||||
// (RountedUpTo(q_count,4),)
|
||||
// Need to be have multiple of 4 elements alocated and
|
||||
// be initizalized If you need to compute over multiple chunks of kv's you can
|
||||
// keep values between calls to this function and avoid explicit merge.
|
||||
template <typename KV_T, typename Q_T>
|
||||
HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
const hwy::Span<const MatPtrT<KV_T>> kvs, int q_count,
|
||||
const hwy::Span<const Q_T * HWY_RESTRICT> q_T_in_groups_up_to_4,
|
||||
hwy::Span<const size_t> start_pos_per_query,
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
|
||||
float* HWY_RESTRICT max_logits) {
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
using VF = hn::Vec<DF>;
|
||||
using DU = hn::ScalableTag<uint32_t>;
|
||||
[[maybe_unused]] const DU du;
|
||||
constexpr int kTileSize = gcpp::KVCache::kTileSize;
|
||||
HWY_LANES_CONSTEXPR size_t kHTileSize = hn::Lanes(df);
|
||||
constexpr int kNumQueriesPerGroup = 4;
|
||||
constexpr int kNumQueriesPerLoop =
|
||||
(!HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3)) ? 8 : 4;
|
||||
constexpr int kNumGroupsPerLoop = kNumQueriesPerLoop / kNumQueriesPerGroup;
|
||||
const size_t full_groups_of_queries = q_count / kNumQueriesPerGroup;
|
||||
const size_t num_loops = hwy::DivCeil(q_count, kNumQueriesPerLoop);
|
||||
const size_t qkv_dim = att_out.Cols();
|
||||
HWY_DASSERT(kHTileSize <= hn::MaxLanes(df));
|
||||
HWY_LANES_CONSTEXPR size_t step_size = kHTileSize * 2;
|
||||
size_t smallest_start_pos = std::numeric_limits<size_t>::max();
|
||||
size_t largest_last_pos = std::numeric_limits<size_t>::min();
|
||||
for (size_t i = 0; i < start_pos_per_query.size(); ++i) {
|
||||
smallest_start_pos = std::min(smallest_start_pos, start_pos_per_query[i]);
|
||||
largest_last_pos = std::max(largest_last_pos, last_pos_per_query[i]);
|
||||
}
|
||||
// start / end positions per group of 4 queries.
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> pos_data(num_loops * 4);
|
||||
hwy::Span<size_t> min_start_pos_per_group(pos_data.data(), num_loops);
|
||||
hwy::Span<size_t> max_start_pos_per_group(pos_data.data() + num_loops,
|
||||
num_loops);
|
||||
hwy::Span<size_t> min_last_pos_per_group(pos_data.data() + 2 * num_loops,
|
||||
num_loops);
|
||||
hwy::Span<size_t> max_last_pos_per_group(pos_data.data() + 3 * num_loops,
|
||||
num_loops);
|
||||
|
||||
for (size_t i = 0; i < num_loops; ++i) {
|
||||
size_t min_start = std::numeric_limits<size_t>::max();
|
||||
size_t max_start = 0;
|
||||
size_t min_last = std::numeric_limits<size_t>::max();
|
||||
size_t max_last = 0;
|
||||
for (int j = 0; j < kNumQueriesPerLoop; ++j) {
|
||||
if (i * kNumQueriesPerLoop + j < q_count) {
|
||||
min_start = std::min(min_start,
|
||||
start_pos_per_query[i * kNumQueriesPerLoop + j]);
|
||||
max_start = std::max(max_start,
|
||||
start_pos_per_query[i * kNumQueriesPerLoop + j]);
|
||||
min_last =
|
||||
std::min(min_last, last_pos_per_query[i * kNumQueriesPerLoop + j]);
|
||||
max_last =
|
||||
std::max(max_last, last_pos_per_query[i * kNumQueriesPerLoop + j]);
|
||||
}
|
||||
}
|
||||
min_start_pos_per_group[i] = min_start;
|
||||
max_start_pos_per_group[i] = max_start;
|
||||
min_last_pos_per_group[i] = min_last;
|
||||
max_last_pos_per_group[i] = max_last;
|
||||
}
|
||||
const size_t base_pos = smallest_start_pos - (smallest_start_pos % kTileSize);
|
||||
const size_t rem = smallest_start_pos % kTileSize;
|
||||
const size_t num_skipped_sub_tiles = rem / step_size;
|
||||
size_t position = base_pos + num_skipped_sub_tiles * step_size;
|
||||
[[maybe_unused]] float one_over_cap = 1.0f / att_cap;
|
||||
std::vector<MatPtrT<float>> att_out_per_query;
|
||||
att_out_per_query.reserve(num_loops);
|
||||
for (size_t i = 0; i < num_loops; ++i) {
|
||||
att_out_per_query.emplace_back("att_out",
|
||||
Extents2D(kNumQueriesPerLoop, qkv_dim));
|
||||
att_out_per_query.back().SetPtr(att_out.Row(i * kNumQueriesPerLoop),
|
||||
att_out.Stride());
|
||||
}
|
||||
size_t current_kv_start_offset = 0;
|
||||
size_t current_kv_idx = 0;
|
||||
|
||||
auto inner_loop = [&]<int kNumQueries>(int q_group_idx) HWY_ATTR {
|
||||
int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup);
|
||||
if (position + step_size <= min_start_pos_per_group[loop_idx] ||
|
||||
position > max_last_pos_per_group[loop_idx]) {
|
||||
return;
|
||||
}
|
||||
VF x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1;
|
||||
VF x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1;
|
||||
const size_t pos_in_tile = position % kTileSize;
|
||||
// tile base can point to same tile as previous loop iteration, hence no
|
||||
// HWY_RESTRICT
|
||||
// KVs are unaligned and we only use unaligned loads in this implementation.
|
||||
const KV_T* tile_base =
|
||||
reinterpret_cast<const KV_T*>(kvs[current_kv_idx].RowBytes(
|
||||
(position - current_kv_start_offset) / kTileSize));
|
||||
|
||||
const KV_T* v_tile =
|
||||
tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim;
|
||||
const Q_T* q_group = q_T_in_groups_up_to_4[q_group_idx];
|
||||
const Q_T* q2_group = nullptr;
|
||||
if (kNumQueries > 4) {
|
||||
q2_group = q_T_in_groups_up_to_4[q_group_idx + 1];
|
||||
}
|
||||
if constexpr (IsF32<Q_T>()) {
|
||||
const KV_T* k_transposed_tile = tile_base + pos_in_tile;
|
||||
QDotKTilexUpTo8TransposedKDoubleWidth<kNumQueries>(
|
||||
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
|
||||
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
} else if constexpr (IsBF16<Q_T>()) {
|
||||
const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
|
||||
QDotKTilexUpTo8TransposedKDoubleWidthBF16<kNumQueries>(
|
||||
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
|
||||
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
} else {
|
||||
static_assert(
|
||||
false,
|
||||
"Query type type not supported, only float and BF16 are supported");
|
||||
}
|
||||
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
kNumQueries - kFirstHalfAmountOfQueries;
|
||||
ApplySoftCap<kFirstHalfAmountOfQueries * 2>(
|
||||
df, att_cap, one_over_cap, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0,
|
||||
x_2_p_1, x_3_p_0, x_3_p_1);
|
||||
if constexpr (kNumQueries > 4) {
|
||||
ApplySoftCap<kSecondHalfAmountOfQueries * 2>(
|
||||
df, att_cap, one_over_cap, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
|
||||
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
}
|
||||
|
||||
if (position < max_start_pos_per_group[loop_idx] ||
|
||||
position + step_size - 1 > min_last_pos_per_group[loop_idx]) {
|
||||
ApplyMasking<kNumQueries>(
|
||||
df, du, position,
|
||||
start_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
|
||||
last_pos_per_query.data() + q_group_idx * kNumQueriesPerGroup,
|
||||
x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
|
||||
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
|
||||
x_7_p_0, x_7_p_1);
|
||||
}
|
||||
HWY_ALIGN float scales[kNumQueriesPerLoop];
|
||||
// HWY_UNROLL(kNumQueriesPerLoop)
|
||||
for (size_t i = 0; i < kNumQueriesPerLoop; ++i) {
|
||||
scales[i] = 1.0f;
|
||||
}
|
||||
FlashAttentionTileStepAndApplySoftCap<kNumQueries>(
|
||||
df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
|
||||
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
|
||||
kNumQueriesPerGroup);
|
||||
if constexpr (IsF32<Q_T>()) {
|
||||
MulByConstAndAddTileUpTo8<kNumQueries>(
|
||||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
|
||||
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
|
||||
} else if constexpr (IsBF16<Q_T>()) {
|
||||
MulByConstAndAddTileUpTo8_BF16<kNumQueries>(
|
||||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
|
||||
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
|
||||
}
|
||||
};
|
||||
|
||||
while (position <= largest_last_pos) {
|
||||
while (position - current_kv_start_offset >=
|
||||
kvs[current_kv_idx].Rows() * kTileSize) {
|
||||
current_kv_start_offset += kvs[current_kv_idx].Rows() * kTileSize;
|
||||
current_kv_idx++;
|
||||
}
|
||||
int group_idx = 0;
|
||||
for (; group_idx + kNumGroupsPerLoop <= full_groups_of_queries;
|
||||
group_idx += kNumGroupsPerLoop) {
|
||||
inner_loop.template operator()<kNumQueriesPerLoop>(group_idx);
|
||||
}
|
||||
if (group_idx < full_groups_of_queries) {
|
||||
inner_loop.template operator()<4>(group_idx);
|
||||
group_idx++;
|
||||
}
|
||||
switch (q_count % kNumQueriesPerGroup) {
|
||||
case 1:
|
||||
inner_loop.template operator()<1>(group_idx);
|
||||
break;
|
||||
case 2:
|
||||
inner_loop.template operator()<2>(group_idx);
|
||||
break;
|
||||
case 3:
|
||||
inner_loop.template operator()<3>(group_idx);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
position += step_size;
|
||||
}
|
||||
}
|
||||
|
||||
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
hwy::Span<const MatPtr> kvs, int q_count,
|
||||
const hwy::Span<const float* HWY_RESTRICT> q_T_in_groups_up_to_4,
|
||||
hwy::Span<const size_t> start_pos_per_query,
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
|
||||
float* HWY_RESTRICT max_logits) {
|
||||
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
|
||||
return TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
|
||||
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
|
||||
});
|
||||
}
|
||||
|
||||
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
||||
hwy::Span<const MatPtr> kvs, int q_count,
|
||||
const hwy::Span<const BF16 * HWY_RESTRICT> q_T_in_groups_up_to_4,
|
||||
hwy::Span<const size_t> start_pos_per_query,
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
|
||||
float* HWY_RESTRICT max_logits) {
|
||||
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
|
||||
return TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
|
||||
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
|
||||
});
|
||||
}
|
||||
|
||||
// Rounds n to a number that can be used as the number of Q rows in a tile
|
||||
// of flash attention.
|
||||
static size_t RoundToSuitablePowerOf2(size_t n) {
|
||||
|
|
|
|||
|
|
@ -22,46 +22,78 @@
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/flash_structs.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/query.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Passed to HWY_VISIT_TARGETS; declares for one target.
|
||||
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const BF16* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
Tile4FlashState TileFlashAttention4( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<KV_t>& k, size_t start_pos, \
|
||||
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
|
||||
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, const AttentionActivations& activations, \
|
||||
MatPtrT<float>& att_out, const uint32_t* HWY_RESTRICT out_offsets, \
|
||||
ThreadingContext& ctx, const size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, const MatPtr& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx, AttentionImpl attention_impl); \
|
||||
\
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \
|
||||
namespace NAMESPACE { \
|
||||
void RMSNormAndPositionalEncoding( \
|
||||
size_t num_tokens, const QBatch& qbatch, MatPtrT<float>& q, \
|
||||
const MatPtr& query_norm_scale, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \
|
||||
\
|
||||
void SingleFlashAttention(size_t start_pos, size_t last_pos, \
|
||||
const BF16* HWY_RESTRICT q, \
|
||||
const MatPtrT<KV_t>& k, const MatPtrT<KV_t>& v, \
|
||||
size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, \
|
||||
float* HWY_RESTRICT att_out, \
|
||||
ThreadingContext& ctx, size_t worker); \
|
||||
\
|
||||
Tile4FlashState TileFlashAttention4( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const MatPtrT<KV_t>& k, size_t start_pos, \
|
||||
const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \
|
||||
size_t max_last_pos, const MatPtrT<KV_t>& v, size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
|
||||
const size_t worker); \
|
||||
\
|
||||
void TileFlashAttention( \
|
||||
const MatPtrT<BF16>& q, const uint32_t* HWY_RESTRICT q_offsets, \
|
||||
const StridedView<BF16>& qT, const MatPtrT<KV_t>& k, \
|
||||
const size_t start_pos, const uint32_t* HWY_RESTRICT last_pos, \
|
||||
const size_t min_last_pos, const size_t max_last_pos, \
|
||||
const MatPtrT<KV_t>& v, const size_t layer_idx, \
|
||||
const AttentionActivationsPtrs& activations, MatPtrT<float>& att_out, \
|
||||
const uint32_t* HWY_RESTRICT out_offsets, ThreadingContext& ctx, \
|
||||
const size_t worker); \
|
||||
\
|
||||
size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \
|
||||
size_t total_tasks, size_t target_parallelism); \
|
||||
\
|
||||
void FlashAttention(size_t num_tokens, size_t target_parallelism, \
|
||||
size_t layer_idx, const MatPtr& query_norm_scale, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx, AttentionImpl attention_impl); \
|
||||
\
|
||||
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits( \
|
||||
hwy::Span<const MatPtr> kvs, int q_count, \
|
||||
const hwy::Span<const float* HWY_RESTRICT> q_T_in_groups_up_to_4, \
|
||||
hwy::Span<const size_t> start_pos_per_query, \
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
|
||||
float* HWY_RESTRICT max_logits); \
|
||||
\
|
||||
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( \
|
||||
hwy::Span<const MatPtr> kvs, int q_count, \
|
||||
const hwy::Span<const BF16 * HWY_RESTRICT> q_T_in_groups_up_to_4, \
|
||||
hwy::Span<const size_t> start_pos_per_query, \
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
|
||||
float* HWY_RESTRICT max_logits); \
|
||||
\
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
// Function declarations for each SIMD target. Allows direct call from the
|
||||
|
|
|
|||
|
|
@ -181,6 +181,298 @@ void TestAttention() {
|
|||
TestFlashAttention(256);
|
||||
}
|
||||
|
||||
const std::vector<float> exp_denominator_sums_gold = {
|
||||
58.722088f, 58.445938f, 58.17153f, 57.89886f,
|
||||
58.580994f, 58.302643f, 58.026085f, 57.751308f};
|
||||
const std::vector<float> max_logits_gold = {
|
||||
0.009613638f, 0.019227259f, 0.02884084f, 0.038454376f,
|
||||
0.04888253f, 0.058658823f, 0.06843502f, 0.078211054f};
|
||||
const std::vector<float> att_out_gold = {
|
||||
0.600945, 0.300472, 0.200315, 0.150236, 0.120189, 0.100158, 0.085849,
|
||||
0.075118, 0.066772, 0.060095, 0.054631, 0.050079, 0.046227, 0.042925,
|
||||
0.040063, 0.037559, 0.035350, 0.033386, 0.031629, 0.030047, 0.028616,
|
||||
0.027316, 0.026128, 0.025039, 0.024038, 0.023113, 0.022257, 0.021462,
|
||||
0.020722, 0.020032, 0.019385, 0.018780, 0.018210, 0.017675, 0.017170,
|
||||
0.016693, 0.016242, 0.015814, 0.015409, 0.015024, 0.014657, 0.014308,
|
||||
0.013975, 0.013658, 0.013354, 0.013064, 0.012786, 0.012520, 0.012264,
|
||||
0.012019, 0.011783, 0.011557, 0.011339, 0.011129, 0.010926, 0.010731,
|
||||
0.010543, 0.010361, 0.010186, 0.010016, 0.009852, 0.009693, 0.009539,
|
||||
0.009390, 0.601890, 0.300945, 0.200630, 0.150473, 0.120378, 0.100315,
|
||||
0.085984, 0.075236, 0.066877, 0.060189, 0.054717, 0.050158, 0.046299,
|
||||
0.042992, 0.040126, 0.037618, 0.035405, 0.033438, 0.031678, 0.030095,
|
||||
0.028661, 0.027359, 0.026169, 0.025079, 0.024076, 0.023150, 0.022292,
|
||||
0.021496, 0.020755, 0.020063, 0.019416, 0.018809, 0.018239, 0.017703,
|
||||
0.017197, 0.016719, 0.016267, 0.015839, 0.015433, 0.015047, 0.014680,
|
||||
0.014331, 0.013997, 0.013679, 0.013375, 0.013085, 0.012806, 0.012539,
|
||||
0.012283, 0.012038, 0.011802, 0.011575, 0.011356, 0.011146, 0.010943,
|
||||
0.010748, 0.010559, 0.010377, 0.010202, 0.010032, 0.009867, 0.009708,
|
||||
0.009554, 0.009405, 0.602835, 0.301418, 0.200945, 0.150709, 0.120567,
|
||||
0.100473, 0.086119, 0.075354, 0.066982, 0.060284, 0.054803, 0.050236,
|
||||
0.046372, 0.043060, 0.040189, 0.037677, 0.035461, 0.033491, 0.031728,
|
||||
0.030142, 0.028706, 0.027402, 0.026210, 0.025118, 0.024113, 0.023186,
|
||||
0.022327, 0.021530, 0.020787, 0.020095, 0.019446, 0.018839, 0.018268,
|
||||
0.017730, 0.017224, 0.016745, 0.016293, 0.015864, 0.015457, 0.015071,
|
||||
0.014703, 0.014353, 0.014019, 0.013701, 0.013396, 0.013105, 0.012826,
|
||||
0.012559, 0.012303, 0.012057, 0.011820, 0.011593, 0.011374, 0.011164,
|
||||
0.010961, 0.010765, 0.010576, 0.010394, 0.010218, 0.010047, 0.009883,
|
||||
0.009723, 0.009569, 0.009419, 0.603780, 0.301890, 0.201260, 0.150945,
|
||||
0.120756, 0.100630, 0.086254, 0.075473, 0.067087, 0.060378, 0.054889,
|
||||
0.050315, 0.046445, 0.043127, 0.040252, 0.037736, 0.035516, 0.033543,
|
||||
0.031778, 0.030189, 0.028751, 0.027445, 0.026251, 0.025158, 0.024151,
|
||||
0.023222, 0.022362, 0.021564, 0.020820, 0.020126, 0.019477, 0.018868,
|
||||
0.018296, 0.017758, 0.017251, 0.016772, 0.016318, 0.015889, 0.015482,
|
||||
0.015095, 0.014726, 0.014376, 0.014041, 0.013722, 0.013417, 0.013126,
|
||||
0.012846, 0.012579, 0.012322, 0.012076, 0.011839, 0.011611, 0.011392,
|
||||
0.011181, 0.010978, 0.010782, 0.010593, 0.010410, 0.010234, 0.010063,
|
||||
0.009898, 0.009738, 0.009584, 0.009434, 0.614887, 0.307443, 0.204962,
|
||||
0.153722, 0.122977, 0.102481, 0.087841, 0.076861, 0.068321, 0.061489,
|
||||
0.055899, 0.051241, 0.047299, 0.043920, 0.040992, 0.038430, 0.036170,
|
||||
0.034160, 0.032362, 0.030744, 0.029280, 0.027949, 0.026734, 0.025620,
|
||||
0.024595, 0.023649, 0.022774, 0.021960, 0.021203, 0.020496, 0.019835,
|
||||
0.019215, 0.018633, 0.018085, 0.017568, 0.017080, 0.016619, 0.016181,
|
||||
0.015766, 0.015372, 0.014997, 0.014640, 0.014300, 0.013975, 0.013664,
|
||||
0.013367, 0.013083, 0.012810, 0.012549, 0.012298, 0.012057, 0.011825,
|
||||
0.011602, 0.011387, 0.011180, 0.010980, 0.010787, 0.010601, 0.010422,
|
||||
0.010248, 0.010080, 0.009918, 0.009760, 0.009608, 0.615864, 0.307932,
|
||||
0.205288, 0.153966, 0.123173, 0.102644, 0.087981, 0.076983, 0.068429,
|
||||
0.061586, 0.055988, 0.051322, 0.047374, 0.043990, 0.041058, 0.038491,
|
||||
0.036227, 0.034215, 0.032414, 0.030793, 0.029327, 0.027994, 0.026777,
|
||||
0.025661, 0.024635, 0.023687, 0.022810, 0.021995, 0.021237, 0.020529,
|
||||
0.019867, 0.019246, 0.018663, 0.018114, 0.017596, 0.017107, 0.016645,
|
||||
0.016207, 0.015791, 0.015397, 0.015021, 0.014663, 0.014322, 0.013997,
|
||||
0.013686, 0.013388, 0.013103, 0.012830, 0.012569, 0.012317, 0.012076,
|
||||
0.011844, 0.011620, 0.011405, 0.011198, 0.010998, 0.010805, 0.010618,
|
||||
0.010438, 0.010264, 0.010096, 0.009933, 0.009776, 0.009623, 0.616841,
|
||||
0.308421, 0.205614, 0.154210, 0.123368, 0.102807, 0.088120, 0.077105,
|
||||
0.068538, 0.061684, 0.056076, 0.051403, 0.047449, 0.044060, 0.041123,
|
||||
0.038553, 0.036285, 0.034269, 0.032465, 0.030842, 0.029373, 0.028038,
|
||||
0.026819, 0.025702, 0.024674, 0.023725, 0.022846, 0.022030, 0.021270,
|
||||
0.020561, 0.019898, 0.019276, 0.018692, 0.018142, 0.017624, 0.017134,
|
||||
0.016671, 0.016233, 0.015816, 0.015421, 0.015045, 0.014687, 0.014345,
|
||||
0.014019, 0.013708, 0.013410, 0.013124, 0.012851, 0.012589, 0.012337,
|
||||
0.012095, 0.011862, 0.011639, 0.011423, 0.011215, 0.011015, 0.010822,
|
||||
0.010635, 0.010455, 0.010281, 0.010112, 0.009949, 0.009791, 0.009638,
|
||||
0.617818, 0.308909, 0.205939, 0.154455, 0.123564, 0.102970, 0.088260,
|
||||
0.077227, 0.068646, 0.061782, 0.056165, 0.051485, 0.047524, 0.044130,
|
||||
0.041188, 0.038614, 0.036342, 0.034323, 0.032517, 0.030891, 0.029420,
|
||||
0.028083, 0.026862, 0.025742, 0.024713, 0.023762, 0.022882, 0.022065,
|
||||
0.021304, 0.020594, 0.019930, 0.019307, 0.018722, 0.018171, 0.017652,
|
||||
0.017162, 0.016698, 0.016258, 0.015841, 0.015445, 0.015069, 0.014710,
|
||||
0.014368, 0.014041, 0.013729, 0.013431, 0.013145, 0.012871, 0.012609,
|
||||
0.012356, 0.012114, 0.011881, 0.011657, 0.011441, 0.011233, 0.011032,
|
||||
0.010839, 0.010652, 0.010471, 0.010297, 0.010128, 0.009965, 0.009807,
|
||||
0.009653};
|
||||
|
||||
void TestTiledFlashAttention() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
|
||||
// tiles size to test the padding logic.
|
||||
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
|
||||
float att_cap = 10.0f;
|
||||
int num_queries = 8;
|
||||
int num_queries_per_timestep = 4;
|
||||
int num_tokens = num_queries / num_queries_per_timestep;
|
||||
int kv_seq_end =
|
||||
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
MatStorageT<float> kv(
|
||||
"kv",
|
||||
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
// fill in kvs with predictable, synthetic data
|
||||
for (int i = 0; i < padded_kv_seq_len; ++i) {
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
const int tile_idx = i / gcpp::KVCache::kTileSize;
|
||||
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
|
||||
const float val_k = 0.01f * (i + 1) / (j + 1);
|
||||
const float val_v = 0.02f * (i + 1) / (j + 1);
|
||||
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset] = val_k;
|
||||
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
|
||||
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j] = val_v;
|
||||
}
|
||||
}
|
||||
std::vector<float> q_float(4 * qkv_dim);
|
||||
std::vector<float> q_float2(4 * qkv_dim);
|
||||
// fill in qs with predictable, synthetic data
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
for (int j = 0; j < qkv_dim; j++) {
|
||||
float val_1 = 0.01f * (i + 1) / (j + 1);
|
||||
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
|
||||
q_float[j * 4 + i] = val_1;
|
||||
q_float2[j * 4 + i] = val_2;
|
||||
}
|
||||
}
|
||||
const float* q_T[2] = {q_float.data(), q_float2.data()};
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
|
||||
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
|
||||
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
|
||||
std::vector<float> max_logits(num_queries_rounded_to_laness);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(i),
|
||||
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
|
||||
exp_denominator_sums[i] = 0.0f;
|
||||
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
|
||||
}
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
|
||||
start_pos_per_query.reserve(num_queries);
|
||||
last_pos_per_query.reserve(num_queries);
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
ssize_t query_last_pos = kv_seq_end + token_idx;
|
||||
ssize_t query_start_pos =
|
||||
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
|
||||
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
|
||||
++q_head_idx) {
|
||||
start_pos_per_query.push_back(query_start_pos);
|
||||
last_pos_per_query.push_back(query_last_pos);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kvs, num_queries, hwy::Span<const float*>(q_T, 2),
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
// TODO: Replace with Other implementation for generating goldens.
|
||||
// Current values are taken from a point in time where code was run with gemma
|
||||
// and output looked good. Not ideal but should be good enough to test the
|
||||
// plumbing and detect regressions.
|
||||
PrintMatPtr(att_out);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestTiledFlashAttentionBF16() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
|
||||
// tiles size to test the padding logic.
|
||||
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
|
||||
float att_cap = 10.0f;
|
||||
int num_queries = 8;
|
||||
int num_queries_per_timestep = 4;
|
||||
int num_tokens = num_queries / num_queries_per_timestep;
|
||||
int kv_seq_end =
|
||||
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
MatStorageT<BF16> kv(
|
||||
"kv",
|
||||
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
// fill in kvs with predictable, synthetic data
|
||||
for (int i = 0; i < padded_kv_seq_len; i++) {
|
||||
for (int j = 0; j < qkv_dim; j+=2) {
|
||||
const int tile_idx = i / gcpp::KVCache::kTileSize;
|
||||
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
|
||||
const float val_k_1 = 0.01f * (i + 1) / (j + 1);
|
||||
const float val_k_2 = 0.01f * (i + 1) / (j + 2);
|
||||
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(val_k_1);
|
||||
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(val_k_2);
|
||||
}
|
||||
}
|
||||
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
|
||||
for (int i = 0; i < padded_kv_seq_len; i += 2) {
|
||||
for (int j = 0; j < qkv_dim; j++) {
|
||||
const int tile_idx = i / gcpp::KVCache::kTileSize;
|
||||
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
|
||||
const float val_v_1 = 0.02f * (i + 1) / (j + 1);
|
||||
const float val_v_2 = 0.02f * (i + 2) / (j + 1);
|
||||
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(val_v_1);
|
||||
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(val_v_2);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<BF16> q_float(num_queries_per_timestep * qkv_dim);
|
||||
std::vector<BF16> q_float2(num_queries_per_timestep * qkv_dim);
|
||||
// fill in qs with predictable, synthetic data
|
||||
for (int i = 0; i < num_queries_per_timestep; ++i) {
|
||||
for (int j = 0; j < qkv_dim; j += 2) {
|
||||
q_float[j * num_queries_per_timestep + i * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 1));
|
||||
q_float[j * num_queries_per_timestep + i * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 2));
|
||||
|
||||
q_float2[j * num_queries_per_timestep + i * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(
|
||||
0.01f * (i + num_queries_per_timestep + 1) / (j + 1));
|
||||
q_float2[j * num_queries_per_timestep + i * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(
|
||||
0.01f * (i + num_queries_per_timestep + 1) / (j + 2));
|
||||
}
|
||||
}
|
||||
const BF16* q_T[2] = {q_float.data(), q_float2.data()};
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
||||
HWY_LANES_CONSTEXPR size_t lanes = 4;
|
||||
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
|
||||
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
|
||||
std::vector<float> max_logits(num_queries_rounded_to_laness);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(i),
|
||||
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
|
||||
exp_denominator_sums[i] = 0.0f;
|
||||
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
|
||||
}
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
|
||||
start_pos_per_query.reserve(num_queries);
|
||||
last_pos_per_query.reserve(num_queries);
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
ssize_t query_last_pos = kv_seq_end + token_idx;
|
||||
ssize_t query_start_pos =
|
||||
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
|
||||
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
|
||||
++q_head_idx) {
|
||||
start_pos_per_query.push_back(query_start_pos);
|
||||
last_pos_per_query.push_back(query_last_pos);
|
||||
}
|
||||
}
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
||||
kvs, num_queries, hwy::Span<const BF16*>(q_T, 2),
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
// TODO: Replace with Other implementation for generating goldens.
|
||||
// Current values are taken from a point in time where code was run with gemma
|
||||
// and output looked good. Not ideal but should be good enough to test the
|
||||
// plumbing and detect regressions.
|
||||
PrintMatPtr(att_out);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -42,7 +42,8 @@
|
|||
// After highway.h
|
||||
#include "gemma/attention.h" // includes highway.h
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
#include "gemma/tiled_attention.h" // includes highway.h
|
||||
#include "gemma/vit.h" // includes highway.h
|
||||
|
||||
#ifndef GEMMA_CC_ONCE
|
||||
#define GEMMA_CC_ONCE
|
||||
|
|
@ -80,6 +81,14 @@ namespace HWY_NAMESPACE {
|
|||
void Attention(LayerAttentionType type, const size_t num_tokens,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||
if (activations.attention_impl == AttentionImpl::kFlashTransposedQs ||
|
||||
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
TiledAttention(
|
||||
activations.attention_impl, num_tokens, layer_idx, layer,
|
||||
activations.attention, qbatch, env,
|
||||
AttentionImplToFlags(activations.attention_impl, HWY_NATIVE_DOT_BF16));
|
||||
return;
|
||||
}
|
||||
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
// TODO: remove flag to enable FlashAttention.
|
||||
|
|
|
|||
|
|
@ -148,6 +148,14 @@ struct RuntimeConfig {
|
|||
// Which attention implementation to use.
|
||||
AttentionImpl attention_impl = AttentionImpl::kFlash;
|
||||
|
||||
// Right now it only work for tiled kv cache, implementations.
|
||||
// If not set, it will be set based on the attention_impl.
|
||||
// F32 for tiled
|
||||
// BF16 for tiled bf16
|
||||
// If you want to use type other than F32 or BF16, you might need to update
|
||||
// call upcasted.
|
||||
std::optional<Type> kv_cache_type = {};
|
||||
|
||||
// Functions operating on the generated tokens.
|
||||
StreamFunc stream_token;
|
||||
BatchStreamFunc batch_stream_token;
|
||||
|
|
|
|||
|
|
@ -51,10 +51,72 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
|
||||
allocator) {}
|
||||
|
||||
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const RuntimeConfig& runtime_config,
|
||||
const Allocator& allocator)
|
||||
: allocator_(allocator) {
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
|
||||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
) {
|
||||
const size_t num_tiles =
|
||||
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
|
||||
tiled_seq_len = num_tiles * kTileSize;
|
||||
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
|
||||
Type kv_cache_type;
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
) {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
|
||||
} else {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
|
||||
}
|
||||
auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size,
|
||||
size_t max_seq_len) {
|
||||
return hwy::DivCeil(
|
||||
std::min(max_seq_len, window_size + prefill_tbatch_size), kTileSize);
|
||||
};
|
||||
|
||||
size_t total_num_tiles = 0;
|
||||
for (size_t window_size : config.attention_window_sizes) {
|
||||
total_num_tiles +=
|
||||
num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size,
|
||||
config.max_seq_len) *
|
||||
config.layer_configs[0].kv_heads;
|
||||
}
|
||||
Extents2D extents(total_num_tiles, tile_length);
|
||||
compact_kv_cache_ptr = MatPtr("kv_tiled", kv_cache_type, extents);
|
||||
compact_kv_cache.AllocateFor(compact_kv_cache_ptr, allocator,
|
||||
MatPadding::kPacked);
|
||||
total_num_tiles = 0;
|
||||
kv_head_ptrs.reserve(config.attention_window_sizes.size() *
|
||||
config.layer_configs[0].kv_heads);
|
||||
for (size_t window_size : config.attention_window_sizes) {
|
||||
for (size_t kv = 0; kv < config.layer_configs[0].kv_heads; ++kv) {
|
||||
size_t num_tiles_per_kv_head =
|
||||
num_tiles_per_head(window_size, runtime_config.prefill_tbatch_size,
|
||||
config.max_seq_len);
|
||||
MatPtr kv_ptr("kv_ptr", kv_cache_type,
|
||||
Extents2D(num_tiles_per_kv_head, tile_length));
|
||||
kv_ptr.SetPtr(compact_kv_cache_ptr.RowBytes(total_num_tiles),
|
||||
compact_kv_cache_ptr.Stride());
|
||||
kv_head_ptrs.emplace_back(std::move(kv_ptr));
|
||||
total_num_tiles += num_tiles_per_kv_head;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
kv_cache = MatStorageT<KV_t>(
|
||||
"kv",
|
||||
Extents2D(CappedSeqLen(config, inference_args), config.KVCacheCols()),
|
||||
allocator, MatPadding::kOdd);
|
||||
}
|
||||
}
|
||||
|
||||
KVCache KVCache::Copy() {
|
||||
KVCache copy(kv_cache.Extents(), allocator_);
|
||||
|
||||
CopyMat(kv_cache, copy.kv_cache);
|
||||
|
||||
CopyMat(compact_kv_cache_ptr, copy.compact_kv_cache_ptr);
|
||||
copy.tiled_seq_len = tiled_seq_len;
|
||||
return copy;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,31 +31,103 @@
|
|||
namespace gcpp {
|
||||
|
||||
using KV_t = float;
|
||||
struct KVCache;
|
||||
|
||||
// A non-owning view of a KVCache.
|
||||
struct KVCachePtr {
|
||||
bool IsEmpty() const { return kv_cache.Rows() == 0; }
|
||||
size_t SeqLen() const;
|
||||
|
||||
bool IsTiled() const;
|
||||
MatPtrT<KV_t> kv_cache;
|
||||
KVCache* cache = nullptr;
|
||||
};
|
||||
|
||||
struct KVCache {
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const Allocator& allocator);
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
||||
const RuntimeConfig& runtime_config, const Allocator& allocator);
|
||||
// Returns a deep copy of the KVCache. Use explicit function instead of
|
||||
// copy ctor to make the cost explicit.
|
||||
KVCache Copy();
|
||||
|
||||
size_t SeqLen() const {
|
||||
if (IsTiled()) {
|
||||
return tiled_seq_len.value();
|
||||
}
|
||||
return kv_cache.Rows();
|
||||
}
|
||||
|
||||
bool IsTiled() const {
|
||||
return tiled_seq_len.has_value();
|
||||
}
|
||||
|
||||
// This function returns a vector of pointers and handles wraparound for local
|
||||
// layers.
|
||||
// You can use this function to get kv's,
|
||||
// it will slice internal circular buffer and give you parts of it that are in
|
||||
// order. Keep in mind that this gives out pointers to tiles, and for local
|
||||
// layers start_pos might be in a middle of the first tile. At start_pos %
|
||||
// kTileSize
|
||||
std::vector<MatPtr> GetPointers(int layer_idx, int kv_head_idx,
|
||||
int num_kv_heads, int start_pos,
|
||||
bool is_global_layer) {
|
||||
if (!IsTiled()) {
|
||||
HWY_ABORT("This function is only meant to be used with tiled KV caches.");
|
||||
}
|
||||
MatPtr& source_ptr = kv_head_ptrs[layer_idx * num_kv_heads + kv_head_idx];
|
||||
if (is_global_layer) {
|
||||
return {source_ptr};
|
||||
}
|
||||
size_t start_tile_mod_window = (start_pos / kTileSize) % source_ptr.Rows();
|
||||
size_t start_len = source_ptr.Rows() - start_tile_mod_window;
|
||||
MatPtr start_ptr("kv_start", source_ptr.GetType(),
|
||||
Extents2D(start_len, source_ptr.Cols()));
|
||||
start_ptr.SetPtr(source_ptr.RowBytes(start_tile_mod_window),
|
||||
source_ptr.Cols());
|
||||
return {start_ptr, source_ptr};
|
||||
}
|
||||
|
||||
static constexpr size_t kTileSize = 32;
|
||||
std::optional<uint32_t> tiled_seq_len = std::nullopt;
|
||||
// Default Format
|
||||
// If tiled_seq_len is not set, then the kv_cache is assumed to be [seq_len,
|
||||
// layers * kv_heads * qkv_dim * 2].
|
||||
//
|
||||
// Tiled Format
|
||||
// If tiled_seq_len is set, the kv cache is stored in tiled format.
|
||||
// Allocations must happen in full tiles.
|
||||
// The order of dimensions on rows is: [layer, kv_head, tile].
|
||||
// The total number of rows is:
|
||||
// num_layers * num_kv_heads * (tiled_seq_len / kTileSize).
|
||||
// Each tile (containing kTileSize elements from the sequence) can be thought
|
||||
// of as storing K^T and V, where K is shaped [kTileSize, qkv_dim].
|
||||
|
||||
// Type erased kv cache. It's compact because local layers are allocated as
|
||||
// circular buffers.
|
||||
MatPtr compact_kv_cache_ptr;
|
||||
MatOwner compact_kv_cache;
|
||||
// Pointers to the raw KV storage indexed by layer and head. This helps
|
||||
// accessing the tiles even though different layers may have a different
|
||||
// number of tiles in storage. All pointers point into compact_kv_cache.
|
||||
|
||||
// To access the tiles of (layer_idx, head_idx), index the array with
|
||||
// layer_idx * num_kv_heads + kv_head_idx.
|
||||
// Or use GetPointers function.
|
||||
|
||||
// The returned MatPtr will have one tile per row. The number of rows for
|
||||
// global layers is max_seq_len/kTileSize. For local layers it is slightly
|
||||
// more than attention_window_size[layer_idx] / kTileSize. For local layers, a
|
||||
// given token_idx is in row (token_idx / kTileSize) %
|
||||
// kv_head_ptrs[...].Rows().
|
||||
std::vector<MatPtr> kv_head_ptrs;
|
||||
MatStorageT<KV_t> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
|
||||
KVCachePtr ToPtr() {
|
||||
return KVCachePtr{
|
||||
.kv_cache = kv_cache,
|
||||
.cache = this,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -67,9 +139,17 @@ struct KVCache {
|
|||
};
|
||||
|
||||
inline size_t KVCachePtr::SeqLen() const {
|
||||
if (IsTiled()) {
|
||||
return cache->tiled_seq_len.value();
|
||||
}
|
||||
return kv_cache.Rows();
|
||||
}
|
||||
|
||||
inline bool KVCachePtr::IsTiled() const {
|
||||
// MPU code create a KVCachePtr without kv_cache.
|
||||
return cache != nullptr && cache->tiled_seq_len.has_value();
|
||||
}
|
||||
|
||||
// Convenience function to create views into KVCaches.
|
||||
std::vector<KVCachePtr> ToKVCachePtrs(const hwy::Span<KVCache>& kv_caches);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,660 @@
|
|||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
// Note: HWY_DISABLED_TARGETS needs to be defined the same everywhere.
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
#include "util/basics.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/tiled_attention.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "gemma/attention.h"
|
||||
#include "gemma/flash_attention.h" // includes highway.h
|
||||
#include "gemma/gemma-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
// Folds one partial online-softmax result into a running accumulator.
// `other_att_out`/`other_softmax_max`/`other_softmax_d` describe a partial
// attention output over a subset of KV tokens: the weighted value sum, the
// maximum logit seen, and the sum of exponentials. The accumulator triple is
// updated in place to represent the union of both token subsets, rescaling
// each side to the common maximum for numerical stability.
static HWY_INLINE void MergeOnlineSoftmax(
    const float* HWY_RESTRICT other_att_out, const float other_softmax_max,
    const float other_softmax_d, int qkv_dim,
    float* HWY_RESTRICT accumulator_att_out, float& accumulator_softmax_max,
    float& accumulator_softmax_d) {
  // Other side attended to zero tokens: nothing to merge.
  if (other_softmax_d == 0.0f) return;

  // Accumulator is still empty: adopt the other side's state wholesale.
  if (accumulator_softmax_d == 0.0f) {
    memcpy(accumulator_att_out, other_att_out,
           qkv_dim * sizeof(*accumulator_att_out));
    accumulator_softmax_max = other_softmax_max;
    accumulator_softmax_d = other_softmax_d;
    return;
  }

  // Rescale both partial sums to the shared (larger) max before combining.
  const float merged_max = std::max(accumulator_softmax_max, other_softmax_max);
  const float scale_acc = std::exp(accumulator_softmax_max - merged_max);
  const float scale_other = std::exp(other_softmax_max - merged_max);
  const float merged_d =
      accumulator_softmax_d * scale_acc + other_softmax_d * scale_other;
  const float inv_merged_d = 1.0f / merged_d;
  // Convex weights for the two partial outputs under the merged denominator.
  const float weight_acc = accumulator_softmax_d * scale_acc * inv_merged_d;
  const float weight_other = other_softmax_d * scale_other * inv_merged_d;
  MulByConst(weight_acc, accumulator_att_out, qkv_dim);
  MulByConstAndAdd(weight_other, other_att_out, accumulator_att_out, qkv_dim);
  accumulator_softmax_max = merged_max;
  accumulator_softmax_d = merged_d;
}
|
||||
|
||||
// Forked from ComputeQKV, but stores the K/V in the tiled cache format:
// within each tile, K is transposed (dim-major) and, for the BF16 kernel,
// additionally pair-packed for VDPBF16PS. Applies key RMS norm and RoPE to K
// before writing.
// KV_T is type stored in the KV cache (typically float or BF16).
template <typename KV_T>
static HWY_INLINE void ComputeQKVTransposedTile(
    size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer,
    AttentionImpl attention_impl, AttentionActivationsPtrs& activations,
    const QBatch& qbatch, const int flags, MatMulEnv& env) {
  PROFILER_ZONE("Gen.Attention.QKVTiled");
  const hwy::Divisor div_qbatch(qbatch.Size());
  const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
  const LayerConfig& layer_config = layer.layer_config;
  const size_t qkv_dim = layer_config.qkv_dim;
  const size_t kv_heads = layer_config.kv_heads;

  // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
  // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
  // This computes Q and stores it in activations.q.
  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w1,
             /*add=*/nullptr, env, activations.q);

  // Compute the combined KV output from pre_att_rms_out.
  // The output shape is [num_interleaved, kv_heads * 2 * qkv_dim].
  const size_t kv_out_cols = kv_heads * 2 * qkv_dim;
  hwy::AlignedFreeUniquePtr<float[]> kv_out_mem =
      hwy::AllocateAligned<float>(num_interleaved * kv_out_cols);
  float* kv_out_data = kv_out_mem.get();
  MatPtrT<float> kv_out_mat("kv_out", Extents2D(num_interleaved, kv_out_cols));
  kv_out_mat.SetPtr(kv_out_data, kv_out_cols);
  CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
             /*add=*/nullptr, env, kv_out_mat);

  // Apply positional encodings and store K/V in tiled format.
  hwy::Divisor div_kv_heads(kv_heads);

  hn::ScalableTag<float> df;
  static hwy::Divisor tile_size_divisor(KVCache::kTileSize);
  // One task per (kv_head, query): each task owns its tiles exclusively,
  // so no synchronization is needed on the cache writes.
  ParallelFor(
      Parallelism::kFlat, kv_heads * qbatch.Size(), env.ctx,
      /*cluster_idx=*/0, Callers::kAttComputeQKV,
      [&](size_t task, size_t worker) HWY_ATTR {
        const size_t kv_head = div_kv_heads.Remainder(task);
        const size_t query_idx = div_kv_heads.Divide(task);
        CompressPerThread tls;
        size_t current_token_idx = 0;
        // Per-task f32 scratch for one decompressed K tile and one V tile.
        float* k_tile_vec = activations.k_tile_vec.Row(task);
        float* v_tile_vec = activations.v_tile_vec.Row(task);
        HWY_ALIGN float k_f32[kMaxQKVDim];
        const size_t start_pos = qbatch.Pos(query_idx);
        const bool is_global_layer =
            activations.config.IsGlobalLayer(layer_idx);
        // Tiles may be non-contiguous (wraparound in local layers), hence a
        // vector of MatPtr spans rather than a single pointer.
        std::vector<MatPtr> kv_ptrs =
            qbatch.KV(query_idx).cache->GetPointers(
                layer_idx, kv_head, kv_heads, start_pos, is_global_layer);
        size_t tile_offset = 0;
        if (!is_global_layer) {
          // Local layers: GetPointers starts at the tile containing
          // start_pos, so tile indices below are relative to it.
          tile_offset = start_pos / KVCache::kTileSize;
        }

        // Outer loop: one iteration per KV-cache tile touched by the batch.
        while (current_token_idx < num_tokens) {
          const size_t pos = start_pos + current_token_idx;
          const size_t pos_mod = activations.div_seq_len.Remainder(pos);
          const size_t tile_idx = tile_size_divisor.Divide(pos_mod);
          const size_t relative_tile_idx = tile_idx - tile_offset;
          KV_T* tile_ptr;
          // Walk the span list to find which span holds this tile row.
          int kv_ptr_idx = 0;
          size_t absolute_rows = 0;
          while (absolute_rows + kv_ptrs[kv_ptr_idx].Rows() <=
                 relative_tile_idx) {
            absolute_rows += kv_ptrs[kv_ptr_idx].Rows();
            kv_ptr_idx++;
          }
          tile_ptr = HWY_RCAST_ALIGNED(
              KV_T*,
              kv_ptrs[kv_ptr_idx].RowBytes(relative_tile_idx - absolute_rows));
          // One tile row holds the K half followed by the V half, each
          // qkv_dim * kTileSize elements.
          PackedSpan<KV_T> tile_packed_span{tile_ptr,
                                            2 * qkv_dim * KVCache::kTileSize};

          // Read-modify-write: decompress existing tile contents so tokens
          // previously written to this tile are preserved.
          DecompressAndZeroPad(df, tile_packed_span, 0, k_tile_vec,
                               qkv_dim * KVCache::kTileSize);
          DecompressAndZeroPad(df, tile_packed_span,
                               qkv_dim * KVCache::kTileSize, v_tile_vec,
                               qkv_dim * KVCache::kTileSize);

          // Inner loop: all batch tokens that land in this tile.
          size_t token_in_tile_idx = current_token_idx;
          while (token_in_tile_idx < num_tokens) {
            const size_t current_pos =
                qbatch.Pos(query_idx) + token_in_tile_idx;
            const size_t current_pos_mod =
                activations.div_seq_len.Remainder(current_pos);
            if (tile_size_divisor.Divide(current_pos_mod) != tile_idx) {
              break;  // Moved to next tile
            }

            // Locate this token's K/V rows in the matmul output; each row is
            // [kv_heads * (K then V)] of qkv_dim each.
            const float* kv_row =
                kv_out_data +
                (token_in_tile_idx * qbatch.Size() + query_idx) * kv_out_cols;
            const float* k_ptr = kv_row + kv_head * 2 * qkv_dim;
            const float* v_ptr = kv_row + kv_head * 2 * qkv_dim + qkv_dim;
            // Copy K so norm/RoPE below do not mutate the shared matmul
            // output (other kv_heads' tasks read adjacent columns).
            hwy::CopyBytes(k_ptr, k_f32, qkv_dim * sizeof(float));
            if (layer.key_norm_scale.HasPtr()) {
              CallUpcasted(&layer.key_norm_scale, [&](const auto* weights_t) {
                RMSNormInplace(weights_t->PackedScale1(), /*w_ofs=*/0, k_f32,
                               qkv_dim, env.ctx, worker);
              });
            }
            // RoPE is applied to K here, once, at write time.
            PositionalEncodingQK(
                k_f32, layer_idx, activations, env.ctx, worker,
                current_pos ,
                /*mul=*/1.0f);

            const size_t in_tile_idx = current_pos_mod % KVCache::kTileSize;
            if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
              const int in_tile_idx_mod_2 = in_tile_idx % 2;
              // NOTE(review): dim is always even here, so dim_mod_2 == 0 and
              // (dim - dim_mod_2) == dim; assumes qkv_dim is even.
              for (int dim = 0; dim < qkv_dim; dim += 2) {
                const int dim_mod_2 = dim % 2;
                // Pack k's in pairs in preparation for BF16 dot product.
                // See flash_attention.cc
                // QDotKTilexUpTo4TransposedKDoubleWidthBF16
                k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
                           in_tile_idx * 2] = k_f32[dim];
                k_tile_vec[(dim - dim_mod_2) * KVCache::kTileSize +
                           in_tile_idx * 2 + 1] = k_f32[dim + 1];
                // Pack v's in pairs
                v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
                           dim * 2 + in_tile_idx_mod_2] = v_ptr[dim];
                v_tile_vec[(in_tile_idx - in_tile_idx_mod_2) * qkv_dim +
                           (dim + 1) * 2 + in_tile_idx_mod_2] = v_ptr[dim + 1];
              }

            } else {
              // Plain layout: K transposed within the tile (dim-major), V
              // written row-major directly into the V half of the tile.
              for (int i = 0; i < qkv_dim; ++i) {
                k_tile_vec[i * KVCache::kTileSize + in_tile_idx] = k_f32[i];
              }
              // Offset = qkv_dim * kTileSize (V half) + in_tile_idx * qkv_dim.
              Compress(v_ptr, qkv_dim, tls, tile_packed_span,
                       qkv_dim * (KVCache::kTileSize + in_tile_idx));
            }

            token_in_tile_idx++;
          }
          // Write back the (modified) K half; for BF16 also the V half,
          // which in that path was staged in v_tile_vec rather than written
          // row-by-row above.
          Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls,
                   tile_packed_span, 0);
          if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
            Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls,
                     tile_packed_span, qkv_dim * KVCache::kTileSize);
          }
          current_token_idx = token_in_tile_idx;
        }
      });
}
|
||||
|
||||
// TODO: optimize with gathers
|
||||
// This format might change in the future, when kernel will be updated to
|
||||
// support more than 8 queries.
|
||||
// Input (num_queries, qkv_dim)
|
||||
// Output (qkv_dim, num_queries)
|
||||
void TransposeQ(const MatPtrT<float>& queries,
|
||||
hwy::Span<float> transposed_queries_span) {
|
||||
const size_t qkv_dim = queries.Cols();
|
||||
const size_t num_queries = queries.Rows();
|
||||
HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim);
|
||||
for (size_t i = 0; i < qkv_dim; i++) {
|
||||
for (size_t j = 0; j < num_queries; ++j) {
|
||||
transposed_queries_span[i * num_queries + j] = queries.Row(j)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transposes queries.
// Input: vector of pointers to subsequent queries (allows for arbitrary
// strides); qkv_dim: dimension of each query.
// Output: `transposed_queries` is filled with shape (qkv_dim, num_queries).
//
// Implementation: for each dimension, a SIMD gather reads that dimension's
// value from every query at once. Offsets are int32 element offsets relative
// to queries[0], so all query pointers must lie within int32 range of
// queries[0] and within one readable allocation.
// NOTE(review): assumes transposed_queries.size() >= num_queries * qkv_dim
// and num_queries > 0 when any work is done — no explicit assert here.
void TransposeStridedQueries(
    hwy::Span<float*> queries, int qkv_dim,
    hwy::Span<float> transposed_queries) {
  namespace hn = hwy::HWY_NAMESPACE;
  using DF = hn::ScalableTag<float>;
  const DF df;
  using VF = hn::Vec<DF>;
  using DI = hn::ScalableTag<int32_t>;
  const DI di;
  using VI = hn::Vec<DI>;
  const size_t lanes = hn::Lanes(df);
  const size_t num_queries = queries.size();
  const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, lanes);
  // Per-query element offsets, padded to a whole number of SIMD vectors.
  std::vector<int32_t, hwy::AlignedAllocator<int32_t>> query_offsets(
      num_queries_rounded_up);
  for (size_t i = 0; i < num_queries; ++i) {
    query_offsets[i] = queries[i] - queries[0];
  }
  for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
    // last offset is the same so gather doesn't read out of bounds
    query_offsets[i] = query_offsets[num_queries - 1];
  }

  for (size_t i = 0; i < qkv_dim; i++) {
    size_t j = 0;
    if (num_queries >= lanes) {
      // Full vectors: gather dimension i of `lanes` queries per iteration.
      for (; j <= num_queries-lanes; j += lanes) {
        const VI offsets = hn::LoadU(di, query_offsets.data() + j);
        VF x = hn::GatherIndex(df, queries[0] + i, offsets);
        hn::StoreU(x, df, transposed_queries.data() + i * num_queries + j);
      }
    }
    if (j < num_queries) {
      // Remainder: gather a full vector (padding repeats the last query's
      // offset, so reads stay in bounds) but store only the valid lanes.
      const VI offsets = hn::LoadU(di, query_offsets.data() + j);
      VF x = hn::GatherIndex(df, queries[0] + i, offsets);
      hn::StoreN(x, df, transposed_queries.data() + i * num_queries + j,
                 num_queries - j);
    }
  }
}
|
||||
|
||||
// Transposes `queries_ptrs` (each entry points at one query of length
// qkv_dim) into groups of at most 4 queries. Each group is stored
// contiguously with shape (qkv_dim, group_size); the final group may hold
// fewer than 4 queries. Returns the backing storage together with one
// pointer per group into that storage.
std::pair<AlignedFloatVector, std::vector<float*>> TransposeQueriesToGroupsOf4(
    hwy::Span<float*> queries_ptrs, int qkv_dim) {
  constexpr int kGroup = 4;
  const int num_queries = queries_ptrs.size();
  const int num_groups = hwy::DivCeil(num_queries, kGroup);
  // Storage is sized for full groups even when the last group is partial.
  AlignedFloatVector storage(num_groups * kGroup * qkv_dim);
  std::vector<float*> group_ptrs;
  group_ptrs.reserve(num_groups);
  for (int g = 0; g < num_groups; ++g) {
    const int group_size = std::min(kGroup, num_queries - g * kGroup);
    float* dst = storage.data() + g * qkv_dim * kGroup;
    group_ptrs.push_back(dst);
    TransposeStridedQueries(
        hwy::Span<float*>(queries_ptrs.data() + g * kGroup, group_size),
        qkv_dim, hwy::Span<float>(dst, qkv_dim * group_size));
  }
  return std::make_pair(std::move(storage), std::move(group_ptrs));
}
|
||||
|
||||
// Converts group-transposed f32 queries (as produced by
// TransposeQueriesToGroupsOf4: one pointer per group, each group laid out
// (qkv_dim, group_size)) into BF16, interleaving each pair of adjacent
// dimensions per query: for even dim, the output stores
// {in[dim][q], in[dim+1][q]} adjacently. This is the operand layout the
// VDPBF16PS-based kernel consumes (see flash_attention.cc).
// `num_queries` is the total query count across groups, used to size the
// possibly-partial last group.
// NOTE(review): the dim_idx += 2 loop reads dimension dim_idx + 1, so
// qkv_dim is assumed even — confirm against supported model configs.
// Returns the BF16 backing storage plus per-group pointers into it.
std::pair<AlignedBF16Vector, std::vector<BF16*>>
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
                                          int qkv_dim, int num_queries) {
  constexpr int kMaxGroupSize = 4;
  int num_groups = queries_ptrs.size();
  AlignedBF16Vector transposed_queries(num_groups * kMaxGroupSize * qkv_dim);
  std::vector<BF16*> transposed_queries_ptrs;
  transposed_queries_ptrs.reserve(num_groups);
  for (int group_idx = 0; group_idx < num_groups; ++group_idx) {
    int group_size =
        std::min(kMaxGroupSize, num_queries - group_idx * kMaxGroupSize);
    transposed_queries_ptrs.push_back(transposed_queries.data() +
                                      group_idx * qkv_dim * kMaxGroupSize);
    for (int dim_idx = 0; dim_idx < qkv_dim; dim_idx += 2) {
      for (int query_idx = 0; query_idx < group_size; ++query_idx) {
        // Even dimension -> even slot of the interleaved BF16 pair.
        transposed_queries_ptrs.back()[dim_idx * group_size + query_idx * 2] =
            hwy::ConvertScalarTo<BF16>(
                queries_ptrs[group_idx][dim_idx * group_size + query_idx]);
        // Odd dimension -> odd slot of the pair.
        transposed_queries_ptrs
            .back()[dim_idx * group_size + query_idx * 2 + 1] =
            hwy::ConvertScalarTo<BF16>(
                queries_ptrs[group_idx]
                            [(dim_idx + 1) * group_size + query_idx]);
      }
    }
  }
  return std::make_pair(std::move(transposed_queries),
                        std::move(transposed_queries_ptrs));
}
|
||||
|
||||
template <typename T>
|
||||
static HWY_INLINE void MaybeResizeMatStorage(MatStorageT<T>& mat_storage,
|
||||
int rows, int cols,
|
||||
const char* name,
|
||||
const Allocator& allocator) {
|
||||
if (mat_storage.Rows() != rows || mat_storage.Cols() != cols) {
|
||||
mat_storage = MatStorageT<T>(name, Extents2D(rows, cols), allocator,
|
||||
MatPadding::kOdd);
|
||||
}
|
||||
}
|
||||
|
||||
// clang-format off
// Schedules TiledFlashAttention for all heads, tokens and batch.
// Returns partial results in the same order as queries in `activations.q`.
// Might not work yet for prefix lm.
// To help understanding how to use this function below is description of how
// parameters are used:
//
// attention_impl - Used to determine attention kernel to use.
// num_query_tokens - number of tokens/timesteps in processed in a single batch
// it will influence how many queries kvs are evaluated against.
// layer_idx - layer index
// layer - used to get kv_heads, heads, qkv_dim
// activations - reads: activations.q queries, att_cap, IsGlobalLayer
// qbatch - kv cache, Pos / EndPrefix
// ctx - threading context
//
// Structure: phase 1 (hierarchical ParallelFor) computes partial flash
// attention per (qbatch, kv_head, context-substring) sub-task; phase 2
// (flat ParallelFor) merges sub-task partials via MergeOnlineSoftmax into
// activations.att_out / softmax_d / softmax_max.
// clang-format on
void LocalAttentionForAllHeadsTokensAndBatch(
    AttentionImpl attention_impl, const size_t num_query_tokens,
    const size_t layer_idx, const LayerWeightsPtrs& layer,
    AttentionActivationsPtrs& activations, QBatch& qbatch,
    ThreadingContext& ctx) {
  const size_t heads_per_kv_head =
      layer.layer_config.heads / layer.layer_config.kv_heads;

  // Oversubscribe: split each (qbatch, kv_head) task along the context axis
  // until there are at least 2 tasks per worker.
  int core_count = ctx.pools.MaxWorkers();
  int task_multiplier = 1;
  while (qbatch.Size() * layer.layer_config.kv_heads * task_multiplier <
         core_count * 2) {
    task_multiplier++;
  }
  // Finding the smallest context we need to attend to avoid unnecessary
  // overhead when sub-splitting doesn't make sense. This check overestimates
  // context sizes because it ignores [local] layer sizes and explicit
  // qbatch.Prefix settings.
  size_t min_pos = qbatch.Pos(0);
  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
    min_pos = std::min(min_pos, qbatch.Pos(qi));
  }
  if (min_pos / task_multiplier < num_query_tokens) {
    // In case where min_pos / task_multiplier < num_tokens
    // To make sure we don't over count tokens or read out of bounds code
    // requires quite a bit more involved logic.
    // Also there is not much point to splitting the work into more tasks, when
    // amount of work is small.
    task_multiplier = 1;
  }
  [[maybe_unused]] int num_tasks = qbatch.Size() * layer.layer_config.kv_heads;
  [[maybe_unused]] int num_sub_tasks =
      qbatch.Size() * layer.layer_config.kv_heads * task_multiplier;
  HWY_DASSERT_M(activations.q.Rows() == num_query_tokens * qbatch.Size(),
                "qbatch size mismatch");
  int qkv_dim = layer.layer_config.qkv_dim;

  // Grow (never shrink) the per-sub-task scratch; sizes of all should be in
  // sync.
  if (num_sub_tasks > activations.sub_task_att_out->size()) {
    activations.sub_task_att_out->resize(num_sub_tasks);
    activations.sub_task_exp_denominator_sums->resize(num_sub_tasks);
    activations.sub_task_max_logits->resize(num_sub_tasks);
  }
  // Marks sub-tasks whose context substring is empty; the merge phase skips
  // their (stale) scratch.
  std::vector<int> skip_sub_task(num_sub_tasks, 0);

  // This loop parallelizes over qbatch, kv_head and substrings of context
  // tokens. Each parallel invocation handles all query tokens of the given
  // qbatch.
  ParallelFor(
      Parallelism::kHierarchical, num_sub_tasks, ctx,
      /*cluster_idx=*/0, Callers::kFlashAttention,
      [&](size_t task_idx, size_t worker) HWY_ATTR {
        size_t main_task_idx = task_idx / task_multiplier;
        size_t sub_task_idx = task_idx % task_multiplier;
        size_t current_qbatch_idx =
            main_task_idx / layer.layer_config.kv_heads;
        size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
        // First and last context token we will attend to.
        size_t global_start_context_pos = StartPos(
            qbatch.Pos(current_qbatch_idx), activations.config, layer_idx);
        // Keep in mind this is overestimation because some timesteps might not
        // need all tokens due to causal mask.
        // We will use it to determine how to divide work between sub tasks
        // and make sure PrefixEnd is taken into account
        size_t start_context_pos = global_start_context_pos;
        size_t last_context_pos =
            qbatch.Pos(current_qbatch_idx) + num_query_tokens - 1;
        // In some models, context is limited to some prefix - make sure we take
        // that into account.
        const size_t prefix_end = qbatch.PrefixEnd(current_qbatch_idx);
        if (prefix_end > 0 && prefix_end - 1 > last_context_pos) {
          last_context_pos = prefix_end - 1;
        }
        size_t total_num_context_tokens =
            last_context_pos - start_context_pos + 1;
        size_t context_tokens_per_sub_task =
            hwy::DivCeil(total_num_context_tokens, task_multiplier);
        // Restrict tokens to attend to the substring of context tokens that
        // this subtask is responsible for.
        start_context_pos =
            start_context_pos + context_tokens_per_sub_task * sub_task_idx;
        if (start_context_pos > last_context_pos) {
          // Empty substring (only possible for sub_task_idx > 0, since the
          // offset above is 0 for sub-task 0).
          skip_sub_task[task_idx] = 1;
          return;
        }
        last_context_pos =
            std::min(last_context_pos,
                     start_context_pos + context_tokens_per_sub_task - 1);
        // pre-initialize memory [to avoid racy resizes laters].
        // One "query" per (token, q-head-in-group) pair.
        int num_queries = num_query_tokens * heads_per_kv_head;
        std::vector<float*> queries_ptrs;
        queries_ptrs.reserve(num_queries);
        for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
          for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
               ++q_head_idx) {
            queries_ptrs.push_back(
                activations.q.Row(token_idx * qbatch.Size() +
                                  current_qbatch_idx) +
                (kv_head_idx * heads_per_kv_head + q_head_idx) * qkv_dim);
          }
        }
        hwy::Span<float*> queries_ptrs_span(queries_ptrs.data(),
                                            queries_ptrs.size());

        auto [transposed_queries, transposed_queries_ptrs] =
            TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);

        MatStorageT<float>& att_out =
            activations.sub_task_att_out->at(task_idx);
        AlignedFloatVector& exp_denominator_sums =
            activations.sub_task_exp_denominator_sums->at(task_idx);
        AlignedFloatVector& max_logits =
            activations.sub_task_max_logits->at(task_idx);
        MaybeResizeMatStorage(att_out, num_queries, qkv_dim, "att_out",
                              ctx.allocator);
        for (int i = 0; i < num_queries; ++i) {
          hwy::ZeroBytes(att_out.Row(i),
                         att_out.Cols() * sizeof(decltype(att_out.Row(i)[0])));
        }

        // Kernel reads in blocks of 8; round up and initialize the padding
        // too. Max logits start at a large negative value (not lowest(), to
        // leave headroom for subtraction).
        int num_queries_rounded_to_8 = hwy::RoundUpTo(num_queries, 8);
        exp_denominator_sums.resize(num_queries_rounded_to_8);
        max_logits.resize(num_queries_rounded_to_8);
        for (int i = 0; i < num_queries_rounded_to_8; ++i) {
          exp_denominator_sums[i] = 0.0f;
          max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
        }
        // Get pointers to the KVCache tiles, starting at global_start_pos
        // Returns multiple matrices for non-contiguous memory, for example as a
        // result of the wraparound in local layers.
        std::vector<MatPtr> kv_ptrs =
            qbatch.KV(current_qbatch_idx)
                .cache->GetPointers(
                    layer_idx, kv_head_idx, layer.layer_config.kv_heads,
                    global_start_context_pos,
                    activations.config.IsGlobalLayer(layer_idx));

        std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
        std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
        start_pos_per_query.reserve(num_queries);
        last_pos_per_query.reserve(num_queries);
        // Position of the first token in the first tile whose pointer was
        // returned above. Allows for handling of token positions relative to
        // the KV tiles returned above.
        size_t rounded_down_global_start_pos =
            hwy::RoundDownTo(global_start_context_pos, KVCache::kTileSize);
        for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
          int64_t global_query_pos =
              qbatch.Pos(current_qbatch_idx) + token_idx;
          // Intersect context to attend to for this specific query token
          // to the context tokens of the current subtask (causal mask upper
          // bound).
          int64_t query_last_context_pos = std::min(
              static_cast<int64_t>(last_context_pos), global_query_pos);
          // This max is to not go into negative values, for the same reason we
          // use int64_t and not size_t here. Lower bound comes from the
          // layer's attention window.
          int64_t query_start_context_pos = std::max(
              global_query_pos -
                  static_cast<int64_t>(
                      activations.config.attention_window_sizes[layer_idx]) +
                  1,
              static_cast<int64_t>(start_context_pos));

          // Turn token position into KV-tile relative token positions.
          query_last_context_pos -= rounded_down_global_start_pos;
          query_start_context_pos -= rounded_down_global_start_pos;
          // Same per-token bounds replicated for each head in the group.
          for (int q_head_idx = 0; q_head_idx < heads_per_kv_head;
               ++q_head_idx) {
            start_pos_per_query.push_back(query_start_context_pos);
            last_pos_per_query.push_back(query_last_context_pos);
          }
        }
        if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
          // pack transposed queries into BF16
          hwy::Span<float*> queries_span(transposed_queries_ptrs.data(),
                                         transposed_queries_ptrs.size());
          auto [_, transposed_queries_ptrs_bf16] =
              TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim,
                                                        num_queries);
          hwy::Span<const BF16*> queries_span_bf16(
              const_cast<const BF16**>(transposed_queries_ptrs_bf16.data()),
              transposed_queries_ptrs_bf16.size());
          DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
              kv_ptrs, num_queries, queries_span_bf16,
              hwy::Span<const size_t>(start_pos_per_query),
              hwy::Span<const size_t>(last_pos_per_query),
              activations.config.att_cap, att_out, exp_denominator_sums.data(),
              max_logits.data());
        } else {
          DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
              kv_ptrs, num_queries,
              hwy::Span<const float*>(
                  const_cast<const float**>(transposed_queries_ptrs.data()),
                  transposed_queries_ptrs.size()),
              hwy::Span<const size_t>(start_pos_per_query),
              hwy::Span<const size_t>(last_pos_per_query),
              activations.config.att_cap, att_out, exp_denominator_sums.data(),
              max_logits.data());
        }
      });

  // This loop takes results from separate subtasks (subsequence of kv) and
  // merges them into single att_out over whole kv sequence.
  ParallelFor(
      Parallelism::kFlat, num_tasks, ctx,
      /*cluster_idx=*/0, Callers::kFlashAttention,
      [&](size_t main_task_idx, size_t worker) HWY_ATTR {
        size_t current_qbatch_idx = main_task_idx / layer.layer_config.kv_heads;
        size_t kv_head_idx = main_task_idx % layer.layer_config.kv_heads;
        for (int token_idx = 0; token_idx < num_query_tokens; ++token_idx) {
          for (int head_in_group_idx = 0; head_in_group_idx < heads_per_kv_head;
               ++head_in_group_idx) {
            const size_t batch_index =
                current_qbatch_idx * num_query_tokens + token_idx;
            const size_t q_head_idx =
                kv_head_idx * heads_per_kv_head + head_in_group_idx;
            const size_t att_out_row_idx =
                token_idx * heads_per_kv_head + head_in_group_idx;
            const size_t activations_att_out_start_idx = q_head_idx * qkv_dim;
            // Sub-task 0 is never skipped (its substring offset is 0), so
            // its result seeds the accumulator by plain copy.
            auto& att_out_0 = activations.sub_task_att_out->at(
                main_task_idx * task_multiplier + 0);
            auto& exp_denominator_sums_0 =
                activations.sub_task_exp_denominator_sums->at(
                    main_task_idx * task_multiplier + 0);
            auto& max_logits_0 = activations.sub_task_max_logits->at(
                main_task_idx * task_multiplier + 0);

            hwy::CopyBytes(att_out_0.Row(att_out_row_idx),
                           activations.att_out.Row(batch_index) +
                               activations_att_out_start_idx,
                           qkv_dim * sizeof(float));
            activations.softmax_d.Row(batch_index)[q_head_idx] =
                exp_denominator_sums_0[token_idx * heads_per_kv_head +
                                       head_in_group_idx];
            activations.softmax_max.Row(batch_index)[q_head_idx] =
                max_logits_0[token_idx * heads_per_kv_head + head_in_group_idx];
            // Fold in the remaining sub-tasks' partial results.
            for (int sub_task_idx = 1; sub_task_idx < task_multiplier;
                 ++sub_task_idx) {
              int task_idx = main_task_idx * task_multiplier + sub_task_idx;
              if (skip_sub_task[task_idx] == 1) {
                continue;
              }
              auto& att_out = activations.sub_task_att_out->at(task_idx);
              auto& exp_denominator_sums =
                  activations.sub_task_exp_denominator_sums->at(task_idx);
              auto& max_logits = activations.sub_task_max_logits->at(task_idx);
              MergeOnlineSoftmax(
                  att_out.Row(att_out_row_idx),
                  max_logits[token_idx * heads_per_kv_head + head_in_group_idx],
                  exp_denominator_sums[token_idx * heads_per_kv_head +
                                       head_in_group_idx],
                  qkv_dim,
                  activations.att_out.Row(batch_index) +
                      activations_att_out_start_idx,
                  activations.softmax_max.Row(batch_index)[q_head_idx],
                  activations.softmax_d.Row(batch_index)[q_head_idx]);
            }
          }
        }
      });
}
|
||||
|
||||
// Entry point for tiled attention on one layer:
// 1) computes Q and writes K/V into the tiled KV cache (element type
//    dispatched on the compact cache's stored type),
// 2) applies RMS norm + positional encoding to Q,
// 3) runs tiled flash attention over all heads/tokens/batch,
// 4) combines the per-head outputs (SumHeads).
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
                    const size_t layer_idx, const LayerWeightsPtrs& layer,
                    AttentionActivationsPtrs& activations, QBatch& qbatch,
                    MatMulEnv& env, int flags) {
  static const auto zone = env.ctx.profiler.AddZone(
      "Gen.TiledAttention", hwy::ProfilerFlags::kInclusive);
  PROFILER_ZONE3(env.ctx.profiler, hwy::Profiler::Thread(), zone);

  const LayerConfig& layer_config = layer.layer_config;

  HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
                "query heads must be a multiple of key-value heads");
  (void)layer_config;  // only used in HWY_DASSERT
  // The KV write path is templated on the cache element type; dispatch on
  // the type the compact cache actually stores.
  if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kBF16) {
    ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
                                   activations, qbatch, flags, env);
  } else {
    ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
                                   activations, qbatch, flags, env);
  }
  // K already received norm/RoPE inside ComputeQKVTransposedTile; this call
  // operates on activations.q with query_norm_scale only.
  RMSNormAndPositionalEncoding(num_tokens, qbatch, activations.q,
                               layer.query_norm_scale, layer_idx, activations,
                               env.ctx);
  LocalAttentionForAllHeadsTokensAndBatch(attention_impl, num_tokens, layer_idx,
                                          layer, activations, qbatch, env.ctx);
  SumHeads(layer, activations, env);
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_

// Per-SIMD-target declarations for the tiled attention implementation
// (gemma/tiled_attention.cc). The function definitions live in per-target
// namespaces generated via foreach_target.h.

#include <stddef.h>

#include <cstddef>
#include <utility>
#include <vector>

#include "gemma/gemma.h"
#include "util/allocator.h"
#include "hwy/aligned_allocator.h"
#include "hwy/highway.h"

namespace gcpp {

// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE)                         \
  namespace NAMESPACE {                                                       \
  /* Full tiled attention for one layer: QKV, flash attention, head sum. */   \
  void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,        \
                      size_t layer_idx, const LayerWeightsPtrs& layer,        \
                      AttentionActivationsPtrs& activations, QBatch& qbatch,  \
                      MatMulEnv& env, int flags);                             \
  /* Transposes strided queries into (qkv_dim, num_queries) layout. */        \
  void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim,        \
                               hwy::Span<float> transposed_queries);          \
  /* Schedules the tiled flash attention kernels; see the .cc for details. */ \
  void LocalAttentionForAllHeadsTokensAndBatch(                               \
      AttentionImpl attention_impl, const size_t num_tokens,                  \
      const size_t layer_idx, const LayerWeightsPtrs& layer,                  \
      AttentionActivationsPtrs& activations, QBatch& qbatch,                  \
      ThreadingContext& ctx);                                                 \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                 \
  }  // namespace NAMESPACE

// Function declarations for each SIMD target. Allows direct call from the
// per-target namespace. We may later replace this with dynamic dispatch if
// the overhead is acceptable.
HWY_VISIT_TARGETS(GEMMA_DECL_TILED_ATTENTION)

#undef GEMMA_DECL_TILED_ATTENTION
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TILED_ATTENTION_H_
|
||||
|
|
@ -0,0 +1,749 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "gemma/tiled_attention_test.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "gemma/tiled_attention.h"
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Pointwise;
|
||||
|
||||
struct AttentionTestEnv {
|
||||
AttentionTestEnv(
|
||||
int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads,
|
||||
int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx,
|
||||
int layers_total, int qbatch_size, AttentionImpl attention_impl,
|
||||
)
|
||||
: ctx(threading_args), env(ctx) {
|
||||
layer_config.heads = num_heads;
|
||||
layer_config.kv_heads = num_kv_heads;
|
||||
layer_config.qkv_dim = qkv_dim;
|
||||
layer_config.model_dim = qkv_dim * num_heads;
|
||||
|
||||
model_config.attention_window_sizes = {
|
||||
static_cast<uint32_t>(attention_window_size)};
|
||||
model_config.att_cap = att_cap;
|
||||
model_config.max_seq_len = kv_seq_len;
|
||||
model_config.num_layers = layers_total;
|
||||
model_config.model_dim = layer_config.model_dim;
|
||||
model_config.vocab_size = 1; // not vit
|
||||
|
||||
for (int i = 0; i < model_config.num_layers; ++i) {
|
||||
model_config.layer_configs.push_back(layer_config);
|
||||
}
|
||||
tensor_info_registry = std::make_unique<TensorInfoRegistry>(model_config);
|
||||
layer = std::make_unique<LayerWeightsPtrs>(layer_idx, layer_config,
|
||||
*tensor_info_registry);
|
||||
|
||||
runtime_config.attention_impl = attention_impl;
|
||||
inference_args.seq_len = kv_seq_len;
|
||||
|
||||
all_queries.Reserve(qbatch_size);
|
||||
kv_caches.reserve(qbatch_size);
|
||||
for (int q = 0; q < qbatch_size; ++q) {
|
||||
kv_caches.emplace_back(model_config, inference_args, runtime_config,
|
||||
ctx.allocator);
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
MatPtrT<BF16> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
|
||||
for (int j = 0; j < compact_kv_cache.Cols(); ++j) {
|
||||
BF16 val = hwy::ConvertScalarTo<BF16>(hwy::Unpredictable1() *
|
||||
0.01f * (i + j + 1));
|
||||
// split j into if k/v
|
||||
if (j < qkv_dim * gcpp::KVCache::kTileSize) {
|
||||
// split j into dim and in tile offset
|
||||
const int dim = j / gcpp::KVCache::kTileSize;
|
||||
const int in_tile_offset = j % gcpp::KVCache::kTileSize;
|
||||
const int dim_mod_2 = dim % 2;
|
||||
compact_kv_cache.Row(
|
||||
i)[(dim - dim_mod_2) * gcpp::KVCache::kTileSize +
|
||||
in_tile_offset * 2 + dim_mod_2] = val;
|
||||
} else {
|
||||
const int in_tile_offset = j / qkv_dim;
|
||||
const int dim = j % qkv_dim;
|
||||
const int in_tile_offset_mod_2 = in_tile_offset % 2;
|
||||
compact_kv_cache.Row(
|
||||
i)[(in_tile_offset - in_tile_offset_mod_2) * qkv_dim +
|
||||
dim * 2 + in_tile_offset_mod_2] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
|
||||
MatPtrT<float> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
|
||||
FillMatPtrT(compact_kv_cache);
|
||||
} else {
|
||||
FillMatPtrT(kv_caches.back().kv_cache);
|
||||
}
|
||||
all_queries.Append({
|
||||
.prompt = PromptTokens({1, 2, 3}),
|
||||
.mutable_pos = static_cast<size_t>(last_pos),
|
||||
.initial_pos = 0,
|
||||
.prefix_end = 0,
|
||||
.kv_cache = kv_caches.back().ToPtr(),
|
||||
});
|
||||
}
|
||||
|
||||
activations = std::make_unique<Activations>(runtime_config, model_config,
|
||||
qbatch_size * num_tokens,
|
||||
kv_seq_len, ctx, env.row_ptrs);
|
||||
|
||||
qbatch =
|
||||
std::make_unique<QBatch>(/*start_pos=*/0, qbatch_size, all_queries);
|
||||
}
|
||||
|
||||
void SetupWeights() {
|
||||
int model_dim = layer_config.model_dim;
|
||||
int qkv_dim = layer_config.qkv_dim;
|
||||
int num_heads = layer_config.heads;
|
||||
int num_kv_heads = layer_config.kv_heads;
|
||||
|
||||
qkv1_w_storage =
|
||||
MatStorageT<float>("qkv1", Extents2D(model_dim, qkv_dim * num_heads),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
qkv2_w_storage = MatStorageT<float>(
|
||||
"qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator,
|
||||
MatPadding::kPacked);
|
||||
wo_w_storage = MatStorageT<float>("wo", Extents2D(model_dim, model_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
||||
FillMatPtrT(wo_w_storage);
|
||||
layer->att_weights = wo_w_storage;
|
||||
FillMatPtrT(qkv1_w_storage);
|
||||
FillMatPtrT(qkv2_w_storage);
|
||||
layer->qkv_einsum_w1 = qkv1_w_storage;
|
||||
layer->qkv_einsum_w2 = qkv2_w_storage;
|
||||
|
||||
query_norm_scale = MatStorageT<float>("query_norm", qkv_dim, ctx.allocator);
|
||||
FillMatPtrT(query_norm_scale);
|
||||
layer->query_norm_scale = query_norm_scale;
|
||||
|
||||
key_norm_scale = MatStorageT<float>("key_norm", qkv_dim, ctx.allocator);
|
||||
FillMatPtrT(key_norm_scale);
|
||||
layer->key_norm_scale = key_norm_scale;
|
||||
}
|
||||
|
||||
AttentionTestEnv(const AttentionTestEnv&) = delete;
|
||||
AttentionTestEnv& operator=(const AttentionTestEnv&) = delete;
|
||||
AttentionTestEnv(AttentionTestEnv&&) = delete;
|
||||
AttentionTestEnv& operator=(AttentionTestEnv&&) = delete;
|
||||
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx;
|
||||
MatMulEnv env;
|
||||
LayerConfig layer_config;
|
||||
ModelConfig model_config;
|
||||
std::unique_ptr<TensorInfoRegistry> tensor_info_registry;
|
||||
std::unique_ptr<LayerWeightsPtrs> layer;
|
||||
RuntimeConfig runtime_config;
|
||||
InferenceArgs inference_args;
|
||||
AllQueries all_queries;
|
||||
std::vector<KVCache> kv_caches;
|
||||
std::unique_ptr<Activations> activations;
|
||||
std::unique_ptr<QBatch> qbatch;
|
||||
|
||||
// Weights storage for later tests
|
||||
MatStorageT<float> qkv1_w_storage;
|
||||
MatStorageT<float> qkv2_w_storage;
|
||||
MatStorageT<float> wo_w_storage;
|
||||
MatStorageT<float> query_norm_scale;
|
||||
MatStorageT<float> key_norm_scale;
|
||||
};
|
||||
|
||||
void TestTransposeStridedQueries() {
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
int qkv_dim = 64;
|
||||
int num_queries = 24;
|
||||
AlignedPtr<float[]> input_queries =
|
||||
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
|
||||
AlignedPtr<float[]> output_queries =
|
||||
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
input_queries[i * qkv_dim + j] = i * qkv_dim + j;
|
||||
}
|
||||
}
|
||||
std::vector<float*> queries;
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
queries.push_back(input_queries.get() + i * qkv_dim);
|
||||
}
|
||||
hwy::Span<float*> queries_span(queries.data(), queries.size());
|
||||
|
||||
TransposeStridedQueries(
|
||||
queries_span, qkv_dim,
|
||||
hwy::Span<float>(output_queries.get(), qkv_dim * num_queries));
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_EQ(output_queries[j * num_queries + i],
|
||||
input_queries[i * qkv_dim + j])
|
||||
<< "i=" << i << " j=" << j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Runs LocalAttentionForAllHeadsTokensAndBatch on a 2-token, 2-query batch
// with 2 heads / 2 KV heads and checks the per-head softmax denominators,
// softmax maxima and attention outputs against hard-coded goldens.
void TestLocalAttentionForAllHeadsTokensAndBatch() {
  int qkv_dim = 64;
  int kv_seq_len = 64;
  int num_kv_heads = 2;
  int num_heads = 2;
  int num_tokens = 2;
  int last_pos = 62;  // so token 0 will have 63 and token 1 will have 64 tokens
                      // to attend to.
  float att_cap = 10.0f;
  int layer_idx = 0;
  int layers_total = 1;
  int qbatch_size = 2;
  AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
  // attention_window_size == kv_seq_len, so nothing is masked by windowing.
  AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
                            num_heads, num_tokens, last_pos, att_cap, layer_idx,
                            layers_total, qbatch_size, attention_impl);
  FillMatPtrT(test_env.activations->attention.q);
  LocalAttentionForAllHeadsTokensAndBatch(
      attention_impl, num_tokens, layer_idx, *test_env.layer,
      test_env.activations->attention, *test_env.qbatch, test_env.ctx);

  // print states;
  // Denominators equal the number of attendable positions per token (63 for
  // token 0, 64 for token 1), interleaved per qbatch entry and head.
  std::vector<float> exp_denominator_sums_gold = {63, 63, 64, 64,
                                                  63, 63, 64, 64};
  // Max logits are all 10 — presumably saturated at att_cap (10.0f) by the
  // logit soft-cap; TODO confirm against the kernel.
  std::vector<float> max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10};
  // Golden att_out rows, qkv_dim (64) floats per (batch-row, head) pair, in
  // the same order as the loop below reads them.
  std::vector<float> att_out_gold = {
      30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
      30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
      30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
      30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
      30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
      30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
      30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
      30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
      30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
      30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
      30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
      30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
      30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
      30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
      30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
      30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
      30.415,  30.425,  30.435,  30.445,  30.455,  30.465,  30.475,  30.485,
      30.495,  30.505,  30.515,  30.525,  30.535,  30.545,  30.555,  30.565,
      30.575,  30.585,  30.595,  30.605,  30.615,  30.625,  30.635,  30.645,
      30.655,  30.665,  30.675,  30.685,  30.695,  30.705,  30.715,  30.725,
      30.735,  30.745,  30.755,  30.765,  30.775,  30.785,  30.795,  30.805,
      30.815,  30.825,  30.835,  30.845,  30.855,  30.865,  30.875,  30.885,
      30.895,  30.905,  30.915,  30.925,  30.935,  30.945,  30.955,  30.965,
      30.975,  30.985,  30.995,  31.005,  31.015,  31.025,  31.035,  31.045,
      30.435,  30.445,  30.455,  30.465,  30.475,  30.485,  30.495,  30.505,
      30.515,  30.525,  30.535,  30.545,  30.555,  30.565,  30.575,  30.585,
      30.595,  30.605,  30.615,  30.625,  30.635,  30.645,  30.655,  30.665,
      30.675,  30.685,  30.695,  30.705,  30.715,  30.725,  30.735,  30.745,
      30.755,  30.765,  30.775,  30.785,  30.795,  30.805,  30.815,  30.825,
      30.835,  30.845,  30.855,  30.865,  30.875,  30.885,  30.895,  30.905,
      30.915,  30.925,  30.935,  30.945,  30.955,  30.965,  30.975,  30.985,
      30.995,  31.005,  31.015,  31.025,  31.035,  31.045,  31.055,  31.065,
      30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
      30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
      30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
      30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
      30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
      30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
      30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
      30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
      30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
      30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
      30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
      30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
      30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
      30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
      30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
      30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
      30.415,  30.425,  30.435,  30.445,  30.455,  30.465,  30.475,  30.485,
      30.495,  30.505,  30.515,  30.525,  30.535,  30.545,  30.555,  30.565,
      30.575,  30.585,  30.595,  30.605,  30.615,  30.625,  30.635,  30.645,
      30.655,  30.665,  30.675,  30.685,  30.695,  30.705,  30.715,  30.725,
      30.735,  30.745,  30.755,  30.765,  30.775,  30.785,  30.795,  30.805,
      30.815,  30.825,  30.835,  30.845,  30.855,  30.865,  30.875,  30.885,
      30.895,  30.905,  30.915,  30.925,  30.935,  30.945,  30.955,  30.965,
      30.975,  30.985,  30.995,  31.005,  31.015,  31.025,  31.035,  31.045,
      30.435,  30.445,  30.455,  30.465,  30.475,  30.485,  30.495,  30.505,
      30.515,  30.525,  30.535,  30.545,  30.555,  30.565,  30.575,  30.585,
      30.595,  30.605,  30.615,  30.625,  30.635,  30.645,  30.655,  30.665,
      30.675,  30.685,  30.695,  30.705,  30.715,  30.725,  30.735,  30.745,
      30.755,  30.765,  30.775,  30.785,  30.795,  30.805,  30.815,  30.825,
      30.835,  30.845,  30.855,  30.865,  30.875,  30.885,  30.895,  30.905,
      30.915,  30.925,  30.935,  30.945,  30.955,  30.965,  30.975,  30.985,
      30.995,  31.005,  31.015,  31.025,  31.035,  31.045,  31.055,  31.065,
  };
  // Query heads per KV head (multi-query grouping).
  const int group_size = num_heads / num_kv_heads;
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) {
      // Activations are laid out with tokens outermost, then qbatch entries.
      int b = token_idx * qbatch_size + q_batch_idx;
      EXPECT_THAT(
          absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b),
                         num_heads),
          Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold)
                                          .subspan(b * num_heads, num_heads)));
      EXPECT_THAT(
          absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b),
                         num_heads),
          Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold)
                                          .subspan(b * num_heads, num_heads)));
      // Check each query head's qkv_dim-wide output slice within the row.
      for (int kv_h = 0; kv_h < num_kv_heads; ++kv_h) {
        for (int g = 0; g < group_size; ++g) {
          const int q_h = kv_h * group_size + g;
          size_t expected_q_idx = b * num_heads + q_h;
          EXPECT_THAT(
              absl::MakeSpan(test_env.activations->attention.att_out.Row(b) +
                                 q_h * qkv_dim,
                             qkv_dim),
              Pointwise(FloatNear(1e-3f),
                        absl::MakeSpan(att_out_gold)
                            .subspan(expected_q_idx * qkv_dim, qkv_dim)));
        }
      }
    }
  }
}
|
||||
|
||||
// Golden attention outputs for TestAttentionMultipleTokens: consecutive
// 64-float (qkv_dim) segments compared row-by-row against att_out, in the
// same row/column order the test iterates. Values were presumably captured
// from a reference run of the kernel — regenerate if the fill pattern or
// configuration changes.
const std::vector<float> AttentionMultipleTokensAttentionGoldens = {
    34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545,
    34.985,  35.0156, 35.046,  35.0765, 35.1068, 35.1373, 35.1678, 35.1982,
    35.2286, 35.2592, 35.2895, 35.32,   35.3506, 35.381,  35.4115, 35.4421,
    35.4725, 35.503,  35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857,
    35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294,
    35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731,
    36.2037, 36.2341, 36.2646, 36.295,  36.3254, 36.356,  36.3863, 36.4168,
    36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605,
    34.6687, 34.6987, 34.7288, 34.759,  34.7891, 34.8192, 34.8495, 34.8795,
    34.9097, 34.9399, 34.97,   35.0002, 35.0302, 35.0604, 35.0906, 35.1206,
    35.1507, 35.181,  35.211,  35.2412, 35.2714, 35.3015, 35.3317, 35.3619,
    35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603,
    35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814,  35.8442,
    35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025,  36.0551, 36.0853,
    36.1155, 36.1456, 36.1759, 36.2059, 36.236,  36.2662, 36.2963, 36.3264,
    36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676,
    37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403,
    37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447,  37.4764,
    37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117,
    37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918,  37.9473,
    37.977,  38.0065, 38.0358, 38.0655, 38.095,  38.1244, 38.1541, 38.1836,
    38.213,  38.2422, 38.2718, 38.3012, 38.3305, 38.36,   38.3895, 38.4187,
    38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546,
    38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861,  38.8907,
    36.9872, 37.0167, 37.046,  37.0752, 37.1047, 37.1341, 37.1633, 37.1928,
    37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278,
    37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621,
    37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966,
    37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319,
    38.1612, 38.1903, 38.2197, 38.249,  38.2781, 38.3075, 38.3368, 38.366,
    38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008,
    38.63,   38.6595, 38.6889, 38.7181, 38.7477, 38.777,  38.8063, 38.8358,
    39.0984, 39.1479, 39.1976, 39.2475, 39.297,  39.3468, 39.3967, 39.4463,
    39.4961, 39.546,  39.5957, 39.6455, 39.695,  39.7447, 39.7946, 39.8441,
    39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425,
    40.2921, 40.342,  40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403,
    40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382,
    41.0878, 41.1375, 41.1874, 41.237,  41.2868, 41.3367, 41.3863, 41.4361,
    41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834,
    41.884,  41.9336, 41.9834, 42.0333, 42.083,  42.1328, 42.1823, 42.232,
    38.9699, 39.0188, 39.068,  39.1173, 39.1663, 39.2155, 39.2648, 39.3138,
    39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071,
    39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009,
    40.15,   40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942,
    40.5433, 40.5925, 40.6419, 40.691,  40.7402, 40.7892, 40.8383, 40.8876,
    40.9366, 40.9857, 41.035,  41.0841, 41.1333, 41.1826, 41.2317, 41.2809,
    41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743,
    41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677,
    43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289,
    43.8766, 43.9241, 43.9722, 44.02,   44.0675, 44.1157, 44.1635, 44.2111,
    44.2583, 44.3062, 44.3538, 44.4011, 44.449,  44.4966, 44.544,  44.5919,
    44.6396, 44.6869, 44.735,  44.7826, 44.8301, 44.8781, 44.9258, 44.9733,
    45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26,   45.3081, 45.356,
    45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364,
    45.7844, 45.832,  45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183,
    46.1657, 46.2138, 46.2615, 46.309,  46.3571, 46.4049, 46.4525, 46.5006,
    43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453,
    43.7928, 43.84,   43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256,
    44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046,
    44.552,  44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841,
    44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265,
    45.3123, 45.3593, 45.407,  45.4543, 45.5014, 45.5491, 45.5965, 45.6436,
    45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236,
    46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404,
    34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872,
    35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171,  35.2016, 35.2322,
    35.2628, 35.2935, 35.324,  35.3547, 35.3854, 35.416,  35.4466, 35.4774,
    35.508,  35.5387, 35.5692, 35.5998, 35.6305, 35.661,  35.6916, 35.7224,
    35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674,
    35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123,
    36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574,
    36.4881, 36.5186, 36.5493, 36.58,   36.6106, 36.6413, 36.6718, 36.7024,
    34.6995, 34.7297, 34.76,   34.7904, 34.8206, 34.8509, 34.8813, 34.9115,
    34.9418, 34.9722, 35.0025, 35.0328, 35.063,  35.0933, 35.1237, 35.1539,
    35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965,
    35.4268, 35.4571, 35.4873, 35.5176, 35.548,  35.5782, 35.6085, 35.6389,
    35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813,
    35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237,
    36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661,
    36.3965, 36.4267, 36.457,  36.4874, 36.5177, 36.548,  36.5782, 36.6085,
    37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905,
    37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277,
    37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643,
    37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942,  37.9716, 38.0011,
    38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209,  38.2387,
    38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416,  38.4456, 38.475,
    38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121,
    38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494,
    37.0359, 37.0655, 37.095,  37.1243, 37.154,  37.1835, 37.2129, 37.2425,
    37.2721, 37.3014, 37.3311, 37.3607, 37.39,   37.4198, 37.4493, 37.4787,
    37.508,  37.5376, 37.567,  37.5963, 37.6259, 37.6553, 37.6846, 37.7142,
    37.7437, 37.773,  37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95,
    37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865,
    38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363,  38.3925, 38.4218,
    38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578,
    38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835,  38.8644, 38.8941,
    39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096,
    39.5597, 39.61,   39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91,
    39.96,   40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108,
    40.3608, 40.411,  40.4608, 40.5108, 40.561,  40.6109, 40.661,  40.7112,
    40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116,
    41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412,  41.4619, 41.512,
    41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125,
    41.9627, 42.0127, 42.0628, 42.113,  42.163,  42.2131, 42.263,  42.313,
    39.0297, 39.079,  39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759,
    39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716,
    39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679,
    40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514,  40.5637,
    40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595,
    41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554,
    41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512,
    41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471,
    43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255,
    43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101,
    44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933,
    44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077,
    45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414,  45.4621,
    45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449,
    45.8931, 45.9411, 45.9888, 46.037,  46.085,  46.1327, 46.1811, 46.2291,
    46.2768, 46.3252, 46.3733, 46.421,  46.4694, 46.5175, 46.5653, 46.6138,
    43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412,
    43.889,  43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239,
    44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052,
    44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987,
    45.0352, 45.083,  45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703,
    45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512,
    45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335,
    46.181,  46.2292, 46.277,  46.3245, 46.3727, 46.4206, 46.4682, 46.5164,
};
|
||||
|
||||
// End-to-end test of TiledAttention with 4 query heads over 2 KV heads,
// 2 tokens and a query batch of 2; compares every row of att_out against
// AttentionMultipleTokensAttentionGoldens.
void TestAttentionMultipleTokens() {
  int qkv_dim = 64;
  int kv_seq_len = 64;
  int num_kv_heads = 2;
  int num_heads = 4;
  int num_tokens = 2;
  int last_pos = 62;  // so in the tbatch token 0 will have 63 and token 1
                      // will have 64 tokens to attend to.
  float att_cap = 10.0f;
  int layer_idx = 0;
  int layers_total = 1;
  int qbatch_size = 2;
  AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
  // attention_window_size == kv_seq_len, so windowing masks nothing.
  AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
                            num_heads, num_tokens, last_pos, att_cap, layer_idx,
                            layers_total, qbatch_size, attention_impl);
  test_env.SetupWeights();
  // Deterministically fill all activation buffers TiledAttention reads or
  // writes, so the outputs are reproducible against the goldens.
  FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
  FillMatPtrT(test_env.activations->attention.q);
  FillMatPtrT(test_env.activations->attention.vit_Q);
  FillMatPtrT(test_env.activations->attention.vit_K);
  FillMatPtrT(test_env.activations->attention.att);
  FillMatPtrT(test_env.activations->attention.att_out);
  FillMatPtrT(test_env.activations->attention.softmax_max);
  FillMatPtrT(test_env.activations->attention.softmax_d);

  // Flags also depend on whether the current target has native BF16 dot.
  int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
  TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
                 test_env.activations->attention, *test_env.qbatch,
                 test_env.env, flags);

  // Dump the actual output to stderr to ease debugging on mismatch.
  std::cerr << "att_out\n";
  PrintMatPtr(test_env.activations->attention.att_out);
  // Row-wise approximate comparison (tolerance 1e-3) against the goldens.
  for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
    EXPECT_TRUE(hwy::CompareArraySimilar(
        AttentionMultipleTokensAttentionGoldens.data() +
            i * test_env.activations->attention.att_out.Cols(),
        test_env.activations->attention.att_out.Row(i),
        test_env.activations->attention.att_out.Cols(), 1e-3,
        hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
        << "att_out mismatch for query: " << i;
  }
}
|
||||
|
||||
// Exercises TiledAttention (kFlashTransposedQs) when the local attention
// window (32) is smaller than the KV sequence length (34), so queries near
// the window boundary must mask out the oldest cached positions. Output rows
// of att_out are compared against locally captured goldens.
void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
  int qkv_dim = 64;
  int kv_seq_len = 34;
  int num_kv_heads = 2;
  int num_heads = 4;
  int num_tokens = 2;
  // NOTE(review): with last_pos = 31 and a window of 32 the two tbatch tokens
  // straddle the window boundary; the original "63 and 64 tokens" wording
  // appears copied from the last_pos = 62 test -- confirm intended counts.
  int last_pos = 31;
  float att_cap = 10.0f;
  int layer_idx = 0;
  int layers_total = 1;
  int qbatch_size = 2;
  int attention_window_size = 32;
  AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
  AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size,
                            num_kv_heads, num_heads, num_tokens, last_pos,
                            att_cap, layer_idx, layers_total, qbatch_size,
                            attention_impl);
  test_env.SetupWeights();
  // Deterministically fill every activation buffer so the goldens below are
  // reproducible across runs and targets.
  FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
  FillMatPtrT(test_env.activations->attention.q);
  FillMatPtrT(test_env.activations->attention.vit_Q);
  FillMatPtrT(test_env.activations->attention.vit_K);
  FillMatPtrT(test_env.activations->attention.att);
  FillMatPtrT(test_env.activations->attention.att_out);
  FillMatPtrT(test_env.activations->attention.softmax_max);
  FillMatPtrT(test_env.activations->attention.softmax_d);

  int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
  TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
                 test_env.activations->attention, *test_env.qbatch,
                 test_env.env, flags);

  std::cerr << "att_out\n";
  // Golden att_out values, one row (qkv_dim floats) per query; presumably
  // captured from a reference attention run with the same fill pattern --
  // TODO confirm provenance.
  std::vector<float> att_out_golden_test_local = {
      39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596,
      39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649,
      40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707,
      40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876,
      40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813,
      41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867,
      41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092,
      42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974,
      39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115,
      39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117,
      39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124,
      40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126,
      40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129,
      41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131,
      41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134,
      41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137,
      43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874,
      44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776,
      44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665,
      45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558,
      45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466,
      45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135,
      46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249,
      46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153,
      43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934,
      44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816,
      44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684,
      44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557,
      45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444,
      45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307,
      46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185,
      46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068,
      48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903,
      49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188,
      50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484,
      50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577,
      51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056,
      52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343,
      53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763,
      53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918,
      48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235,
      49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428,
      49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632,
      50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827,
      51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022,
      52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217,
      52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413,
      53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609,
      57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306,
      58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277,
      59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722,
      59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172,
      60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153,
      61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088,
      61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051,
      62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024,
      57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162,
      58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553,
      58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458,
      59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374,
      60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317,
      61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215,
      61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141,
      62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076,
      39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246,
      39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324,
      40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408,
      40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487,
      40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567,
      41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646,
      41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726,
      42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806,
      39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575,
      39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778,
      40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381,
      40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838,
      40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866,
      41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894,
      41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922,
      42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951,
      43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287,
      44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797,
      44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071,
      45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628,
      45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856,
      45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469,
      46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392,
      46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032,
      43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923,
      44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829,
      44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721,
      45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617,
      45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529,
      45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416,
      46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319,
      46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226,
      48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194,
      49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531,
      50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878,
      51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215,
      51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553,
      52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892,
      53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923,
      54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569,
      48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25,
      49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745,
      50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999,
      50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244,
      51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149,
      52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736,
      52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982,
      53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228,
      57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528,
      58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299,
      59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289,
      60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629,
      60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332,
      61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302,
      62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314,
      62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335,
      57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578,
      58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559,
      59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512,
      59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476,
      60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467,
      61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412,
      61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386,
      62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237};
  PrintMatPtr(test_env.activations->attention.att_out);
  // Compare each query's output row against the golden row, with a 1e-3
  // relative tolerance (float path, so tight).
  for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
    EXPECT_TRUE(hwy::CompareArraySimilar(
        att_out_golden_test_local.data() +
            i * test_env.activations->attention.att_out.Cols(),
        test_env.activations->attention.att_out.Row(i),
        test_env.activations->attention.att_out.Cols(), 1e-3,
        hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
        << "att_out mismatch for query: " << i;
  }
}
|
||||
|
||||
// Runs the bf16 flash-attention path (kFlashTransposedQsBF16) with the same
// configuration as TestAttentionMultipleTokens and checks att_out against the
// shared float goldens, with a looser 1e-1 tolerance to absorb bf16 rounding.
void TestAttentionMultipleTokensBF16() {
  int qkv_dim = 64;
  int kv_seq_len = 64;
  int num_kv_heads = 2;
  int num_heads = 4;
  int num_tokens = 2;
  int last_pos = 62;  // so in the tbatch token 0 will have 63 and token 1
                      // will have 64 tokens to attend to.
  float att_cap = 10.0f;
  int layer_idx = 0;
  int layers_total = 1;
  int qbatch_size = 2;
  AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16;
  // attention_window_size == kv_seq_len, i.e. no local-window masking here.
  AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
                            num_heads, num_tokens, last_pos, att_cap, layer_idx,
                            layers_total, qbatch_size, attention_impl);
  test_env.SetupWeights();
  // Deterministic fills so the run is comparable with the float goldens.
  FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
  FillMatPtrT(test_env.activations->attention.q);
  FillMatPtrT(test_env.activations->attention.vit_Q);
  FillMatPtrT(test_env.activations->attention.vit_K);
  FillMatPtrT(test_env.activations->attention.att);
  FillMatPtrT(test_env.activations->attention.att_out);
  FillMatPtrT(test_env.activations->attention.softmax_max);
  FillMatPtrT(test_env.activations->attention.softmax_d);

  int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
  TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
                 test_env.activations->attention, *test_env.qbatch,
                 test_env.env, flags);
  std::cerr << "att_out\n";
  PrintMatPtr(test_env.activations->attention.att_out);
  for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
    EXPECT_TRUE(hwy::CompareArraySimilar(
        AttentionMultipleTokensAttentionGoldens.data() +
            i * test_env.activations->attention.att_out.Cols(),
        test_env.activations->attention.att_out.Row(i),
        test_env.activations->attention.att_out.Cols(), 1e-1,
        hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
        << "att_out mismatch for query: " << i;
  }
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE

namespace gcpp {
// Registers each per-target test above so Highway runs it once for every
// enabled SIMD target.
HWY_BEFORE_TEST(TiledAttentionTest);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestTransposeStridedQueries);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
                      TestLocalAttentionForAllHeadsTokensAndBatch);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokens);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest, TestAttentionMultipleTokensBF16);
HWY_EXPORT_AND_TEST_P(TiledAttentionTest,
                      TestAttentionMultipleTokensAttentionWindowSizeEdgeCase);

HWY_AFTER_TEST();

}  // namespace gcpp

#endif
|
||||
444
ops/ops-inl.h
444
ops/ops-inl.h
|
|
@ -1026,6 +1026,450 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void MulByConstAndAddTile4(
|
|||
}
|
||||
HWY_DASSERT(size == i);
|
||||
}
|
||||
template <int32_t N, typename DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void StoreUpTo8Times2(DF df, MatPtrT<float>& out,
|
||||
size_t start_col, VF out0_0, VF out0_1,
|
||||
VF out1_0, VF out1_1, VF out2_0,
|
||||
VF out2_1, VF out3_0, VF out3_1,
|
||||
VF out4_0, VF out4_1, VF out5_0,
|
||||
VF out5_1, VF out6_0, VF out6_1,
|
||||
VF out7_0, VF out7_1) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
hn::Store(out0_0, df, out.Row(0) + start_col);
|
||||
hn::Store(out0_1, df, out.Row(0) + start_col + NF);
|
||||
if constexpr (N >= 2) {
|
||||
hn::Store(out1_0, df, out.Row(1) + start_col);
|
||||
hn::Store(out1_1, df, out.Row(1) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
hn::Store(out2_0, df, out.Row(2) + start_col);
|
||||
hn::Store(out2_1, df, out.Row(2) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
hn::Store(out3_0, df, out.Row(3) + start_col);
|
||||
hn::Store(out3_1, df, out.Row(3) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
hn::Store(out4_0, df, out.Row(4) + start_col);
|
||||
hn::Store(out4_1, df, out.Row(4) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
hn::Store(out5_0, df, out.Row(5) + start_col);
|
||||
hn::Store(out5_1, df, out.Row(5) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
hn::Store(out6_0, df, out.Row(6) + start_col);
|
||||
hn::Store(out6_1, df, out.Row(6) + start_col + NF);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
hn::Store(out7_0, df, out.Row(7) + start_col);
|
||||
hn::Store(out7_1, df, out.Row(7) + start_col + NF);
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, typename DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void LoadAndMulUpTo8Times2(
|
||||
DF df, MatPtrT<float>& out, size_t column, const float* HWY_RESTRICT scales,
|
||||
VF& out0_0, VF& out0_1, VF& out1_0, VF& out1_1, VF& out2_0, VF& out2_1,
|
||||
VF& out3_0, VF& out3_1, VF& out4_0, VF& out4_1, VF& out5_0, VF& out5_1,
|
||||
VF& out6_0, VF& out6_1, VF& out7_0, VF& out7_1) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
out0_0 = hn::Load(df, out.Row(0) + column);
|
||||
out0_0 = hn::Mul(out0_0, hn::Set(df, scales[0]));
|
||||
out0_1 = hn::Load(df, out.Row(0) + column + NF);
|
||||
out0_1 = hn::Mul(out0_1, hn::Set(df, scales[0]));
|
||||
if constexpr (N >= 2) {
|
||||
out1_0 = hn::Load(df, out.Row(1) + column);
|
||||
out1_0 = hn::Mul(out1_0, hn::Set(df, scales[1]));
|
||||
out1_1 = hn::Load(df, out.Row(1) + column + NF);
|
||||
out1_1 = hn::Mul(out1_1, hn::Set(df, scales[1]));
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
out2_0 = hn::Load(df, out.Row(2) + column);
|
||||
out2_0 = hn::Mul(out2_0, hn::Set(df, scales[2]));
|
||||
out2_1 = hn::Load(df, out.Row(2) + column + NF);
|
||||
out2_1 = hn::Mul(out2_1, hn::Set(df, scales[2]));
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
out3_0 = hn::Load(df, out.Row(3) + column);
|
||||
out3_0 = hn::Mul(out3_0, hn::Set(df, scales[3]));
|
||||
out3_1 = hn::Load(df, out.Row(3) + column + NF);
|
||||
out3_1 = hn::Mul(out3_1, hn::Set(df, scales[3]));
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
out4_0 = hn::Load(df, out.Row(4) + column);
|
||||
out4_0 = hn::Mul(out4_0, hn::Set(df, scales[4]));
|
||||
out4_1 = hn::Load(df, out.Row(4) + column + NF);
|
||||
out4_1 = hn::Mul(out4_1, hn::Set(df, scales[4]));
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
out5_0 = hn::Load(df, out.Row(5) + column);
|
||||
out5_0 = hn::Mul(out5_0, hn::Set(df, scales[5]));
|
||||
out5_1 = hn::Load(df, out.Row(5) + column + NF);
|
||||
out5_1 = hn::Mul(out5_1, hn::Set(df, scales[5]));
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
out6_0 = hn::Load(df, out.Row(6) + column);
|
||||
out6_0 = hn::Mul(out6_0, hn::Set(df, scales[6]));
|
||||
out6_1 = hn::Load(df, out.Row(6) + column + NF);
|
||||
out6_1 = hn::Mul(out6_1, hn::Set(df, scales[6]));
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
out7_0 = hn::Load(df, out.Row(7) + column);
|
||||
out7_0 = hn::Mul(out7_0, hn::Set(df, scales[7]));
|
||||
out7_1 = hn::Load(df, out.Row(7) + column + NF);
|
||||
out7_1 = hn::Mul(out7_1, hn::Set(df, scales[7]));
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8(
|
||||
DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1,
|
||||
const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1,
|
||||
const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1,
|
||||
const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1,
|
||||
const VF& c7_p0, const VF& c7_p1, const VType* HWY_RESTRICT v_tile,
|
||||
MatPtrT<float>& out) {
|
||||
static_assert(N <= 8);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const size_t qkv_dim = out.Cols();
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
PackedSpan<const VType> v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF);
|
||||
|
||||
size_t i = 0;
|
||||
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
|
||||
HWY_ALIGN float consts_buffer[kMaxLanes * N * 2];
|
||||
hn::Store(c0_p0, df, consts_buffer);
|
||||
hn::Store(c0_p1, df, consts_buffer + kMaxLanes);
|
||||
if constexpr (N >= 2) {
|
||||
hn::Store(c1_p0, df, consts_buffer + 2 * kMaxLanes);
|
||||
hn::Store(c1_p1, df, consts_buffer + 3 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
hn::Store(c2_p0, df, consts_buffer + 4 * kMaxLanes);
|
||||
hn::Store(c2_p1, df, consts_buffer + 5 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
hn::Store(c3_p0, df, consts_buffer + 6 * kMaxLanes);
|
||||
hn::Store(c3_p1, df, consts_buffer + 7 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
hn::Store(c4_p0, df, consts_buffer + 8 * kMaxLanes);
|
||||
hn::Store(c4_p1, df, consts_buffer + 9 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
hn::Store(c5_p0, df, consts_buffer + 10 * kMaxLanes);
|
||||
hn::Store(c5_p1, df, consts_buffer + 11 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
hn::Store(c6_p0, df, consts_buffer + 12 * kMaxLanes);
|
||||
hn::Store(c6_p1, df, consts_buffer + 13 * kMaxLanes);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
hn::Store(c7_p0, df, consts_buffer + 14 * kMaxLanes);
|
||||
hn::Store(c7_p1, df, consts_buffer + 15 * kMaxLanes);
|
||||
}
|
||||
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
|
||||
while (i + NF * 2 <= qkv_dim) {
|
||||
VF out0_0, out1_0, out2_0, out3_0, out4_0, out5_0, out6_0, out7_0;
|
||||
VF out0_1, out1_1, out2_1, out3_1, out4_1, out5_1, out6_1, out7_1;
|
||||
LoadAndMulUpTo8Times2<N>(df, out, i, scales, out0_0, out0_1, out1_0, out1_1,
|
||||
out2_0, out2_1, out3_0, out3_1, out4_0, out4_1,
|
||||
out5_0, out5_1, out6_0, out6_1, out7_0, out7_1);
|
||||
for (int lane = 0; lane < NF; ++lane) {
|
||||
VF xI1, xI2;
|
||||
Decompress2(df, v_span, qkv_dim * lane + i, xI1, xI2);
|
||||
|
||||
out0_0 = hn::MulAdd(xI1, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]),
|
||||
out0_0);
|
||||
out0_1 = hn::MulAdd(xI2, hn::Set(df, consts_buffer[lane + 0 * kMaxLanes]),
|
||||
out0_1);
|
||||
if constexpr (N >= 2) {
|
||||
out1_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_0);
|
||||
out1_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 2 * kMaxLanes]), out1_1);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
out2_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_0);
|
||||
out2_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 4 * kMaxLanes]), out2_1);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
out3_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_0);
|
||||
out3_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 6 * kMaxLanes]), out3_1);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
out4_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_0);
|
||||
out4_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 8 * kMaxLanes]), out4_1);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
out5_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_0);
|
||||
out5_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 10 * kMaxLanes]), out5_1);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
out6_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_0);
|
||||
out6_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 12 * kMaxLanes]), out6_1);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
out7_0 = hn::MulAdd(
|
||||
xI1, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_0);
|
||||
out7_1 = hn::MulAdd(
|
||||
xI2, hn::Set(df, consts_buffer[lane + 14 * kMaxLanes]), out7_1);
|
||||
}
|
||||
VF xI3, xI4;
|
||||
Decompress2(df, v_span, qkv_dim * (NF + lane) + i, xI3, xI4);
|
||||
|
||||
out0_0 = hn::MulAdd(xI3, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]),
|
||||
out0_0);
|
||||
out0_1 = hn::MulAdd(xI4, hn::Set(df, consts_buffer[lane + 1 * kMaxLanes]),
|
||||
out0_1);
|
||||
if constexpr (N >= 2) {
|
||||
out1_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_0);
|
||||
out1_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 3 * kMaxLanes]), out1_1);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
out2_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_0);
|
||||
out2_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 5 * kMaxLanes]), out2_1);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
out3_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_0);
|
||||
out3_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 7 * kMaxLanes]), out3_1);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
out4_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_0);
|
||||
out4_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 9 * kMaxLanes]), out4_1);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
out5_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_0);
|
||||
out5_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 11 * kMaxLanes]), out5_1);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
out6_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_0);
|
||||
out6_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 13 * kMaxLanes]), out6_1);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
out7_0 = hn::MulAdd(
|
||||
xI3, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_0);
|
||||
out7_1 = hn::MulAdd(
|
||||
xI4, hn::Set(df, consts_buffer[lane + 15 * kMaxLanes]), out7_1);
|
||||
}
|
||||
}
|
||||
StoreUpTo8Times2<N>(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0,
|
||||
out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1,
|
||||
out6_0, out6_1, out7_0, out7_1);
|
||||
|
||||
i += 2 * NF;
|
||||
}
|
||||
HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
|
||||
DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0,
|
||||
VF c1_p1, VF c2_p0, VF c2_p1, VF c3_p0, VF c3_p1, VF c4_p0, VF c4_p1,
|
||||
VF c5_p0, VF c5_p1, VF c6_p0, VF c6_p1, VF c7_p0, VF c7_p1,
|
||||
VType* HWY_RESTRICT v_tile, MatPtrT<float>& out) {
|
||||
static_assert(N <= 8);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const size_t qkv_dim = out.Cols();
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
using DBF = hn::ScalableTag<BF16>;
|
||||
const DBF dbf;
|
||||
using VBF = hn::Vec<DBF>;
|
||||
PackedSpan<const VType> v_span = MakeConstSpan(v_tile, qkv_dim * 2 * NF);
|
||||
HWY_ALIGN BF16 cs[N * kMaxLanes * 2];
|
||||
PackedSpan<BF16> cs_span = MakeSpan(cs, N * kMaxLanes * 2);
|
||||
float* cs_as_float = HWY_RCAST_ALIGNED(float*, cs);
|
||||
Compress2(df, c0_p0, c0_p1, cs_span, 0);
|
||||
if constexpr (N >= 2) {
|
||||
Compress2(df, c1_p0, c1_p1, cs_span, kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
Compress2(df, c2_p0, c2_p1, cs_span, 2 * kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
Compress2(df, c3_p0, c3_p1, cs_span, 3 * kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
Compress2(df, c4_p0, c4_p1, cs_span, 4 * kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
Compress2(df, c5_p0, c5_p1, cs_span, 5 * kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
Compress2(df, c6_p0, c6_p1, cs_span, 6 * kMaxLanes * 2);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
Compress2(df, c7_p0, c7_p1, cs_span, 7 * kMaxLanes * 2);
|
||||
}
|
||||
VF zero = hn::Zero(df);
|
||||
size_t i = 0;
|
||||
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
|
||||
while (i + NF * 2 <= qkv_dim) {
|
||||
VF out0_0, out1_0, out2_0, out3_0;
|
||||
VF out0_1, out1_1, out2_1, out3_1;
|
||||
VF out4_0, out5_0, out6_0, out7_0;
|
||||
VF out4_1, out5_1, out6_1, out7_1;
|
||||
VF helper_out0_0 = hn::Zero(df), helper_out0_1 = hn::Zero(df),
|
||||
helper_out1_0 = hn::Zero(df), helper_out1_1 = hn::Zero(df),
|
||||
helper_out2_0 = hn::Zero(df), helper_out2_1 = hn::Zero(df),
|
||||
helper_out3_0 = hn::Zero(df), helper_out3_1 = hn::Zero(df),
|
||||
helper_out4_0 = hn::Zero(df), helper_out4_1 = hn::Zero(df),
|
||||
helper_out5_0 = hn::Zero(df), helper_out5_1 = hn::Zero(df),
|
||||
helper_out6_0 = hn::Zero(df), helper_out6_1 = hn::Zero(df),
|
||||
helper_out7_0 = hn::Zero(df), helper_out7_1 = hn::Zero(df);
|
||||
LoadAndMulUpTo8Times2<N>(df, out, i, scales, out0_0, out0_1, out1_0, out1_1,
|
||||
out2_0, out2_1, out3_0, out3_1, out4_0, out4_1,
|
||||
out5_0, out5_1, out6_0, out6_1, out7_0, out7_1);
|
||||
for (int lane = 0; lane < NF; ++lane) {
|
||||
VBF xI, xI2;
|
||||
Decompress2(dbf, v_span, 2 * qkv_dim * lane + i * 2, xI, xI2);
|
||||
|
||||
// Set pair of c scales for 2 value vectors
|
||||
out0_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_0,
|
||||
helper_out0_0);
|
||||
out0_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2, hn::BitCast(dbf, hn::Set(df, cs_as_float[lane])), out0_1,
|
||||
helper_out0_1);
|
||||
if constexpr (N >= 2) {
|
||||
out1_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])),
|
||||
out1_0, helper_out1_0);
|
||||
out1_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + kMaxLanes])),
|
||||
out1_1, helper_out1_1);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
out2_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])),
|
||||
out2_0, helper_out2_0);
|
||||
out2_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 2 * kMaxLanes])),
|
||||
out2_1, helper_out2_1);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
out3_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])),
|
||||
out3_0, helper_out3_0);
|
||||
out3_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 3 * kMaxLanes])),
|
||||
out3_1, helper_out3_1);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
out4_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])),
|
||||
out4_0, helper_out4_0);
|
||||
out4_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 4 * kMaxLanes])),
|
||||
out4_1, helper_out4_1);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
out5_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])),
|
||||
out5_0, helper_out5_0);
|
||||
out5_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 5 * kMaxLanes])),
|
||||
out5_1, helper_out5_1);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
out6_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])),
|
||||
out6_0, helper_out6_0);
|
||||
out6_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 6 * kMaxLanes])),
|
||||
out6_1, helper_out6_1);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
out7_0 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])),
|
||||
out7_0, helper_out7_0);
|
||||
out7_1 = hn::ReorderWidenMulAccumulate(
|
||||
df, xI2,
|
||||
hn::BitCast(dbf, hn::Set(df, cs_as_float[lane + 7 * kMaxLanes])),
|
||||
out7_1, helper_out7_1);
|
||||
}
|
||||
}
|
||||
#if HWY_NATIVE_DOT_BF16 == 0
|
||||
out0_0 = hn::Add(out0_0, helper_out0_0);
|
||||
out0_1 = hn::Add(out0_1, helper_out0_1);
|
||||
if constexpr (N >= 2) {
|
||||
out1_0 = hn::Add(out1_0, helper_out1_0);
|
||||
out1_1 = hn::Add(out1_1, helper_out1_1);
|
||||
}
|
||||
if constexpr (N >= 3) {
|
||||
out2_0 = hn::Add(out2_0, helper_out2_0);
|
||||
out2_1 = hn::Add(out2_1, helper_out2_1);
|
||||
}
|
||||
if constexpr (N >= 4) {
|
||||
out3_0 = hn::Add(out3_0, helper_out3_0);
|
||||
out3_1 = hn::Add(out3_1, helper_out3_1);
|
||||
}
|
||||
if constexpr (N >= 5) {
|
||||
out4_0 = hn::Add(out4_0, helper_out4_0);
|
||||
out4_1 = hn::Add(out4_1, helper_out4_1);
|
||||
}
|
||||
if constexpr (N >= 6) {
|
||||
out5_0 = hn::Add(out5_0, helper_out5_0);
|
||||
out5_1 = hn::Add(out5_1, helper_out5_1);
|
||||
}
|
||||
if constexpr (N >= 7) {
|
||||
out6_0 = hn::Add(out6_0, helper_out6_0);
|
||||
out6_1 = hn::Add(out6_1, helper_out6_1);
|
||||
}
|
||||
if constexpr (N >= 8) {
|
||||
out7_0 = hn::Add(out7_0, helper_out7_0);
|
||||
out7_1 = hn::Add(out7_1, helper_out7_1);
|
||||
}
|
||||
#endif
|
||||
StoreUpTo8Times2<N>(df, out, i, out0_0, out0_1, out1_0, out1_1, out2_0,
|
||||
out2_1, out3_0, out3_1, out4_0, out4_1, out5_0, out5_1,
|
||||
out6_0, out6_1, out7_0, out7_1);
|
||||
|
||||
i += 2 * NF;
|
||||
}
|
||||
HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
// Prescales NF rows of out by scale, then multiplies 1 row of V by the
|
||||
// corresponding values in c0 and adds them to the NF rows of out.
|
||||
|
|
|
|||
Loading…
Reference in New Issue