mirror of https://github.com/google/gemma.cpp.git
Improvements to inference using int8-compressed KVs
Multiplication is done using int16*int16 multiplication instructions, avoiding expensive conversion to f32/bf16; x2 speedup on Zen 3. PiperOrigin-RevId: 888690192
This commit is contained in:
parent
259b757aef
commit
f56d18dd68
|
|
@ -198,6 +198,11 @@ constexpr bool IsInt8() {
|
|||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
|
||||
}
|
||||
|
||||
// Compile-time predicate: true when `Packed`, after stripping cv-qualifiers
// and references, is exactly int16_t.
template <typename Packed>
|
||||
constexpr bool IsInt16() {
|
||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int16_t>();
|
||||
}
|
||||
|
||||
template <typename Packed>
|
||||
constexpr bool IsBF16() {
|
||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();
|
||||
|
|
|
|||
|
|
@ -718,6 +718,7 @@ constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
|
|||
{"flash", AttentionImpl::kFlash},
|
||||
{"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
|
||||
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
|
||||
{"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16},
|
||||
};
|
||||
|
||||
std::string GetAttentionImplName(AttentionImpl impl) {
|
||||
|
|
|
|||
|
|
@ -99,6 +99,7 @@ enum class AttentionImpl {
|
|||
kFlash, // Flash Attention (default)
|
||||
kFlashTransposedQs,
|
||||
kFlashTransposedQsBF16,
|
||||
kFlashTransposedQsInt16,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -558,57 +558,14 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
|
|||
return scale;
|
||||
}
|
||||
|
||||
// Reduces each of x_0..x_3 across its lanes with `reducer` and returns the
// four results in consecutive lanes of a 4-lane vector: lane k holds the
// reduction of x_k (tested with float32).
// `reducer` is called as reducer(a, b) on both full-width (VF) and 4-lane
// (VF4) vectors, so it must accept either.
// NOTE(review): the lane-to-input mapping relies on Lanes(df) being a multiple
// of 4 - confirm for all targets this is compiled for.
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
F reducer) {
|
||||
const DF4 df4;
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
|
||||
// Scratch buffer sized for the largest possible vector, four inputs wide.
HWY_ALIGN T x_transposed[4 * kMaxLanes];
|
||||
// Interleave so that buffer element p comes from input x_(p % 4).
hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
|
||||
// Fold the four stored vectors into one; lane j then only mixes values that
// came from input x_(j % 4).
VF x01 =
|
||||
reducer(hn::Load(df, x_transposed), hn::Load(df, x_transposed + kLanes));
|
||||
VF x23 = reducer(hn::Load(df, x_transposed + 2 * kLanes),
|
||||
hn::Load(df, x_transposed + 3 * kLanes));
|
||||
VF x0123 = reducer(x01, x23);
|
||||
hn::Store(x0123, df, x_transposed);
|
||||
|
||||
// Fold the remaining groups of 4 lanes down to a single 4-lane vector.
VF4 result = hn::Load(df4, x_transposed);
|
||||
for (int i = 1; i < kLanes / 4; ++i) {
|
||||
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Reduces each of x_0..x_7 across its lanes with `reducer` and returns an
// 8-lane vector whose lane k holds the reduction of x_k. Should not be used on
|
||||
// architectures with fewer than 8 lanes per vector.
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
|
||||
VF x_5, VF x_6, VF x_7, F reducer) {
|
||||
// Reduce the inputs in two groups of four.
auto res0123 = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
|
||||
auto res4567 = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
|
||||
|
||||
using DF4 = hn::CappedTag<T, 4>;
|
||||
const DF4 df4;
|
||||
const DF8 df8;
|
||||
// Concatenate the two 4-lane results through an aligned stack buffer.
HWY_ALIGN T buf[8];
|
||||
hn::Store(res0123, df4, buf);
|
||||
hn::Store(res4567, df4, buf + 4);
|
||||
return hn::Load(df8, buf);
|
||||
}
|
||||
|
||||
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
||||
float* HWY_RESTRICT scales) {
|
||||
float* HWY_RESTRICT scales, float* HWY_RESTRICT q_scales_s = nullptr,
|
||||
float max_v_scale = 1.0f) {
|
||||
using DF4 = hn::CappedTag<float, 4>;
|
||||
const DF4 df4;
|
||||
using VF4 = hn::Vec<DF4>;
|
||||
|
|
@ -636,6 +593,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
|
||||
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
|
||||
}
|
||||
VF4 local_max = new_max;
|
||||
VF4 old_max_vf = hn::Set(df4, kNegInf);
|
||||
old_max_vf = hn::LoadU(df4, old_max);
|
||||
new_max = hn::Max(new_max, old_max_vf);
|
||||
|
|
@ -679,8 +637,28 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
const VF4 zero4 = hn::Zero(df4);
|
||||
const VF4 one_over_d =
|
||||
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
|
||||
VF4 q_scale;
|
||||
if (q_scales_s != nullptr) {
|
||||
// max_s = exp(local_max - new_max) / old_d_vf
|
||||
VF4 max_s = hn::Mul(one_over_d, hn::Exp(df4, hn::Sub(local_max, new_max)));
|
||||
|
||||
// Output the unquantize scale directly to array memory:
|
||||
// Because we're capping out at 32767 / max_v_scale, the true scale goes up
|
||||
// proportionately
|
||||
hn::Store(hn::Mul(max_s, hn::Set(df4, max_v_scale / 32767.0f)), df4,
|
||||
q_scales_s);
|
||||
|
||||
// multiplier for x = 32767 * exp(new_max - local_max) / max_v_scale
|
||||
auto max_s_gt_0 = hn::Gt(max_s, zero4);
|
||||
float inv_max_v_scale = 1.0f / std::max(max_v_scale, 1e-10f);
|
||||
VF4 mult = hn::Mul(hn::Set(df4, 32767.0f * inv_max_v_scale),
|
||||
hn::Exp(df4, hn::Sub(new_max, local_max)));
|
||||
q_scale = hn::IfThenElse(max_s_gt_0, mult, zero4);
|
||||
} else {
|
||||
q_scale = one_over_d;
|
||||
}
|
||||
HWY_ALIGN float tmp_one_over_d[4];
|
||||
hn::Store(one_over_d, df4, tmp_one_over_d);
|
||||
hn::Store(q_scale, df4, tmp_one_over_d);
|
||||
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
|
||||
scale = hn::Mul(scale, one_over_d);
|
||||
hn::BlendedStore(scale, changed_max, df4, scales);
|
||||
|
|
@ -713,7 +691,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
|
||||
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
|
||||
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales) {
|
||||
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales,
|
||||
float* HWY_RESTRICT q_scales_s = nullptr, float max_v_scale = 1.0f) {
|
||||
using DF8 = hn::CappedTag<float, 8>;
|
||||
const DF8 df8;
|
||||
using VF8 = hn::Vec<DF8>;
|
||||
|
|
@ -755,6 +734,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
VF8 one_over_cap = hn::Set(df8, one_over_att_cap);
|
||||
new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap)));
|
||||
}
|
||||
VF8 local_max = new_max;
|
||||
VF8 old_max_vf = hn::Set(df8, kNegInf);
|
||||
old_max_vf = hn::LoadU(df8, old_max);
|
||||
new_max = hn::Max(new_max, old_max_vf);
|
||||
|
|
@ -817,8 +797,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
const VF8 zero8 = hn::Zero(df8);
|
||||
const VF8 one_over_d =
|
||||
hn::MaskedDivOr(zero8, non_zero_mask, hn::Set(df8, 1.0f), old_d_vf);
|
||||
VF8 q_scale;
|
||||
if (q_scales_s != nullptr) {
|
||||
VF8 max_s = hn::Mul(one_over_d, hn::Exp(df8, hn::Sub(local_max, new_max)));
|
||||
hn::Store(hn::Mul(max_s, hn::Set(df8, max_v_scale / 32767.0f)), df8,
|
||||
q_scales_s);
|
||||
|
||||
auto max_s_gt_0 = hn::Gt(max_s, zero8);
|
||||
float inv_max_v_scale = 1.0f / std::max(max_v_scale, 1e-10f);
|
||||
VF8 mult = hn::Mul(hn::Set(df8, 32767.0f * inv_max_v_scale),
|
||||
hn::Exp(df8, hn::Sub(new_max, local_max)));
|
||||
q_scale = hn::IfThenElse(max_s_gt_0, mult, zero8);
|
||||
} else {
|
||||
q_scale = one_over_d;
|
||||
}
|
||||
HWY_ALIGN float tmp_one_over_d[8];
|
||||
hn::Store(one_over_d, df8, tmp_one_over_d);
|
||||
hn::Store(q_scale, df8, tmp_one_over_d);
|
||||
hn::BlendedStore(old_d_vf, changed_max, df8, old_d);
|
||||
scale = hn::Mul(scale, one_over_d);
|
||||
hn::BlendedStore(scale, changed_max, df8, scales);
|
||||
|
|
@ -862,7 +856,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
|
|||
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
|
||||
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
|
||||
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx,
|
||||
size_t kNumQueriesPerGroup) {
|
||||
size_t kNumQueriesPerGroup, float* HWY_RESTRICT q_scales_s = nullptr,
|
||||
float max_v_scale = 1.0f) {
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
[[maybe_unused]] constexpr int kSecondHalfAmountOfQueries =
|
||||
kNumQueries - kFirstHalfAmountOfQueries;
|
||||
|
|
@ -870,25 +865,30 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
|
|||
FlashAttentionTileStepAndApplySoftCap4<kFirstHalfAmountOfQueries>(
|
||||
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
|
||||
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
|
||||
max_v_scale);
|
||||
} else {
|
||||
#if HWY_MAX_BYTES <= 16
|
||||
FlashAttentionTileStepAndApplySoftCap4<4>(
|
||||
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
|
||||
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
|
||||
max_v_scale);
|
||||
FlashAttentionTileStepAndApplySoftCap4<kSecondHalfAmountOfQueries>(
|
||||
df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
|
||||
x_6_p1, x_7_p0, x_7_p1,
|
||||
old_max + (q_group_idx + 1) * kNumQueriesPerGroup,
|
||||
old_d + (q_group_idx + 1) * kNumQueriesPerGroup,
|
||||
scales + kNumQueriesPerGroup);
|
||||
scales + kNumQueriesPerGroup,
|
||||
q_scales_s == nullptr ? nullptr : q_scales_s + kNumQueriesPerGroup,
|
||||
max_v_scale);
|
||||
#else
|
||||
FlashAttentionTileStepAndApplySoftCap8<kNumQueries>(
|
||||
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
|
||||
x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
|
||||
x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
|
||||
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
|
||||
max_v_scale);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
@ -998,6 +998,138 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
|
|||
}
|
||||
}
|
||||
|
||||
// Integer Q.K kernel: for up to 8 queries, accumulates products of int16 query
// values and int8 K values (promoted to int16) into int32 accumulators, then
// converts the totals to float and writes them to sumN_p0 / sumN_p1.
//
// q  - int16 values for queries 0..3; read as int32 (two adjacent int16 per
//      load), indexed by i * min(kNumQueries, 4) + query.
// q2 - int16 values for queries 4..7, indexed by
//      i * (kNumQueries - min(kNumQueries, 4)) + (query - 4); only read when
//      kNumQueries >= 5.
// k_transposed_tile - int8 K data; iteration i reads two consecutive
//      half-vectors starting at offset (i * 2) * KVCache::kTileSize.
// qkv_dim - the loop runs qkv_dim / 2 iterations.
// sumN_p0 / sumN_p1 - outputs for query N; only the first kNumQueries pairs
//      are written, the rest are left untouched.
//
// NOTE(review): reading q as int32 pairs suggests that two consecutive
// dimensions are stored adjacently for each query and each K position -
// confirm against the code that builds these layouts.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
|
||||
DF df, const int16_t* HWY_RESTRICT q, const int16_t* HWY_RESTRICT q2,
|
||||
const int8_t* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
|
||||
VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
|
||||
VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
|
||||
VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) {
|
||||
using DI16 = hn::ScalableTag<int16_t>;
|
||||
const DI16 di16;
|
||||
using VI16 = hn::Vec<DI16>;
|
||||
using DI32 = hn::Repartition<int32_t, DF>;
|
||||
const DI32 di32;
|
||||
using VI32 = hn::Vec<DI32>;
|
||||
HWY_DASSERT(hn::Lanes(di16) <= gcpp::KVCache::kTileSize);
|
||||
HWY_DASSERT(kNumQueries <= 8);
|
||||
HWY_DASSERT(gcpp::KVCache::kTileSize >= hn::Lanes(df) * 2);
|
||||
|
||||
// Two int32 accumulators per query and per half (p0 / p1): the main sum and
// the "odd" partial sum that ReorderWidenMulAccumulate updates by reference.
VI32 isum0_p0 = hn::Zero(di32);
|
||||
VI32 isum0_p1 = hn::Zero(di32);
|
||||
VI32 isum1_p0 = hn::Zero(di32), isum1_p1 = hn::Zero(di32);
|
||||
VI32 isum2_p0 = hn::Zero(di32), isum2_p1 = hn::Zero(di32);
|
||||
VI32 isum3_p0 = hn::Zero(di32), isum3_p1 = hn::Zero(di32);
|
||||
VI32 isum4_p0 = hn::Zero(di32), isum4_p1 = hn::Zero(di32);
|
||||
VI32 isum5_p0 = hn::Zero(di32), isum5_p1 = hn::Zero(di32);
|
||||
VI32 isum6_p0 = hn::Zero(di32), isum6_p1 = hn::Zero(di32);
|
||||
VI32 isum7_p0 = hn::Zero(di32), isum7_p1 = hn::Zero(di32);
|
||||
VI32 isum0_odd_p0 = hn::Zero(di32), isum0_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum1_odd_p0 = hn::Zero(di32), isum1_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum2_odd_p0 = hn::Zero(di32), isum2_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum3_odd_p0 = hn::Zero(di32), isum3_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum4_odd_p0 = hn::Zero(di32), isum4_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum5_odd_p0 = hn::Zero(di32), isum5_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum6_odd_p0 = hn::Zero(di32), isum6_odd_p1 = hn::Zero(di32);
|
||||
VI32 isum7_odd_p0 = hn::Zero(di32), isum7_odd_p1 = hn::Zero(di32);
|
||||
|
||||
// View the int16 query buffers as int32 so that one scalar load fetches two
// adjacent int16 values.
const int32_t* q_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q);
|
||||
const int32_t* q2_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q2);
|
||||
// Number of queries served from q and from q2; also the per-iteration stride
// (in int32 units) within each buffer.
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
kNumQueries - kFirstHalfAmountOfQueries;
|
||||
|
||||
// Half-width int8 tag: promoting its lanes to int16 fills one full VI16.
const hn::Repartition<int8_t, DI16> di8;
|
||||
const hn::Half<decltype(di8)> di8_half;
|
||||
for (size_t i = 0; i < qkv_dim / 2; i++) {
|
||||
// Load two consecutive half-vectors of int8 K values and widen to int16.
auto k_dim0 = hn::LoadU(
|
||||
di8_half, k_transposed_tile + (i * 2) * gcpp::KVCache::kTileSize);
|
||||
auto k_dim1 = hn::LoadU(di8_half, k_transposed_tile +
|
||||
(i * 2) * gcpp::KVCache::kTileSize +
|
||||
hn::Lanes(di8_half));
|
||||
auto k_vec1 = hn::PromoteTo(di16, k_dim0);
|
||||
auto k_vec2 = hn::PromoteTo(di16, k_dim1);
|
||||
|
||||
// Broadcasts one int32 (two int16 query values) to every lane, then
// multiply-accumulates it against both K vectors into the int32 sums.
auto accumulate = [&](int32_t q_val, VI32& sum_p0, VI32& sum_p1,
|
||||
VI32& sum_odd_p0, VI32& sum_odd_p1) HWY_ATTR {
|
||||
VI16 q_vec = hn::BitCast(di16, hn::Set(di32, q_val));
|
||||
sum_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_vec, sum_p0,
|
||||
sum_odd_p0);
|
||||
sum_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_vec, sum_p1,
|
||||
sum_odd_p1);
|
||||
};
|
||||
|
||||
// Queries 0..3 come from q; queries 4..7 come from q2.
accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries], isum0_p0, isum0_p1,
|
||||
isum0_odd_p0, isum0_odd_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 1], isum1_p0,
|
||||
isum1_p1, isum1_odd_p0, isum1_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 2], isum2_p0,
|
||||
isum2_p1, isum2_odd_p0, isum2_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 3], isum3_p0,
|
||||
isum3_p1, isum3_odd_p0, isum3_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 0], isum4_p0,
|
||||
isum4_p1, isum4_odd_p0, isum4_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 1], isum5_p0,
|
||||
isum5_p1, isum5_odd_p0, isum5_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 2], isum6_p0,
|
||||
isum6_p1, isum6_odd_p0, isum6_odd_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 3], isum7_p0,
|
||||
isum7_p1, isum7_odd_p0, isum7_odd_p1);
|
||||
}
|
||||
}
|
||||
|
||||
// Combines each main/odd accumulator pair into the final int32 sums and
// converts them to float.
auto convert_to_float = [&](const VI32& sum_p0, const VI32& sum_odd_p0,
|
||||
const VI32& sum_p1, const VI32& sum_odd_p1,
|
||||
VF& out_p0, VF& out_p1) HWY_ATTR {
|
||||
out_p0 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(sum_p0, sum_odd_p0));
|
||||
out_p1 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(sum_p1, sum_odd_p1));
|
||||
};
|
||||
|
||||
convert_to_float(isum0_p0, isum0_odd_p0, isum0_p1, isum0_odd_p1, sum0_p0,
|
||||
sum0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
convert_to_float(isum1_p0, isum1_odd_p0, isum1_p1, isum1_odd_p1, sum1_p0,
|
||||
sum1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
convert_to_float(isum2_p0, isum2_odd_p0, isum2_p1, isum2_odd_p1, sum2_p0,
|
||||
sum2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
convert_to_float(isum3_p0, isum3_odd_p0, isum3_p1, isum3_odd_p1, sum3_p0,
|
||||
sum3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
convert_to_float(isum4_p0, isum4_odd_p0, isum4_p1, isum4_odd_p1, sum4_p0,
|
||||
sum4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
convert_to_float(isum5_p0, isum5_odd_p0, isum5_p1, isum5_odd_p1, sum5_p0,
|
||||
sum5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
convert_to_float(isum6_p0, isum6_odd_p0, isum6_p1, isum6_odd_p1, sum6_p0,
|
||||
sum6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
convert_to_float(isum7_p0, isum7_odd_p0, isum7_p1, isum7_odd_p1, sum7_p0,
|
||||
sum7_p1);
|
||||
}
|
||||
}
|
||||
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename T>
|
||||
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
|
||||
DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2,
|
||||
|
|
@ -1306,6 +1438,47 @@ static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
|
|||
}
|
||||
}
|
||||
|
||||
// Multiplies the pair of vectors belonging to each of the first kNumQueries
// queries by that query's scalar scale from `q_scales`.
//
// q_scales            - per-query scales, indexed by
//                       (q_group_idx + group_offset) * kNumQueriesPerGroup +
//                       query_offset.
// q_group_idx         - index of the group holding queries 0..3; queries 4..7
//                       use the following group (group_offset = 1).
// kNumQueriesPerGroup - stride between groups in `q_scales`.
// xN_p0 / xN_p1       - in/out vectors for query N; pairs at or beyond
//                       kNumQueries are left untouched.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void ApplyQuantizationScale(
|
||||
DF df, const float* HWY_RESTRICT q_scales, int q_group_idx,
|
||||
int kNumQueriesPerGroup, VF& x0_p0, VF& x0_p1, VF& x1_p0, VF& x1_p1,
|
||||
VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0, VF& x4_p1, VF& x5_p0,
|
||||
VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, VF& x7_p1) {
|
||||
// Broadcasts the selected scale to all lanes and scales both vectors in
// place.
auto apply_scale = [&](int group_offset, int query_offset, VF& x_p0,
|
||||
VF& x_p1) HWY_ATTR {
|
||||
int scale_idx =
|
||||
(q_group_idx + group_offset) * kNumQueriesPerGroup + query_offset;
|
||||
VF s = hn::Set(df, q_scales[scale_idx]);
|
||||
x_p0 = hn::Mul(x_p0, s);
|
||||
x_p1 = hn::Mul(x_p1, s);
|
||||
};
|
||||
|
||||
// Queries 0..3 live in the current group, queries 4..7 in the next one.
if constexpr (kNumQueries >= 1) {
|
||||
apply_scale(0, 0, x0_p0, x0_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
apply_scale(0, 1, x1_p0, x1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
apply_scale(0, 2, x2_p0, x2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
apply_scale(0, 3, x3_p0, x3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
apply_scale(1, 0, x4_p0, x4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
apply_scale(1, 1, x5_p0, x5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
apply_scale(1, 2, x6_p0, x6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
apply_scale(1, 3, x7_p0, x7_p1);
|
||||
}
|
||||
}
|
||||
|
||||
// Performs tiled flash attention for arbitrary number of queries
|
||||
// It depends on kv being tiled.
|
||||
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
|
||||
|
|
@ -1317,6 +1490,7 @@ static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
|
|||
// as it will be used to figure out when to switch to the next one.
|
||||
// q_T_in_groups_up_to_4 - Span of float* All except last float*
|
||||
// should have (qkv_dim, 4) Last one can have any size up to 4.
|
||||
// q_scales - Span of float of shape (q_count,) used for Queries in int16 format
|
||||
// start_pos_per_query - start position in kv to start attention from ()
|
||||
// last_pos_per_query - last position in kv to attend to (exclusive)
|
||||
// queries_per_timestep - how many queries begin/end on the same timestep
|
||||
|
|
@ -1332,6 +1506,7 @@ template <typename KV_T, typename Q_T>
|
|||
HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
const hwy::Span<const MatPtrT<KV_T>> kvs, int q_count,
|
||||
const hwy::Span<const Q_T * HWY_RESTRICT> q_T_in_groups_up_to_4,
|
||||
const hwy::Span<const float> q_scales,
|
||||
hwy::Span<const size_t> start_pos_per_query,
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
|
||||
|
|
@ -1406,6 +1581,12 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
size_t current_kv_start_offset = 0;
|
||||
size_t current_kv_idx = 0;
|
||||
|
||||
HWY_ALIGN float q_scales_s[8];
|
||||
float* q_scales_s_ptr = nullptr;
|
||||
if constexpr (IsInt16<Q_T>()) {
|
||||
q_scales_s_ptr = q_scales_s;
|
||||
}
|
||||
float max_v_scale = 1.0f;
|
||||
auto inner_loop = [&]<int kNumQueries>(int q_group_idx) HWY_ATTR {
|
||||
int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup);
|
||||
if (position + step_size <= min_start_pos_per_group[loop_idx] ||
|
||||
|
|
@ -1429,6 +1610,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
if (kNumQueries > 4) {
|
||||
q2_group = q_T_in_groups_up_to_4[q_group_idx + 1];
|
||||
}
|
||||
|
||||
if constexpr (IsF32<Q_T>()) {
|
||||
const KV_T* k_transposed_tile = tile_base + pos_in_tile;
|
||||
QDotKTilexUpTo8TransposedKDoubleWidth<kNumQueries>(
|
||||
|
|
@ -1441,10 +1623,16 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
|
||||
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
} else if constexpr (IsInt16<Q_T>()) {
|
||||
const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
|
||||
QDotKTilexUpTo8TransposedKDoubleWidthInt16<kNumQueries>(
|
||||
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
|
||||
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
|
||||
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
} else {
|
||||
static_assert(
|
||||
false,
|
||||
"Query type type not supported, only float and BF16 are supported");
|
||||
static_assert(false,
|
||||
"Query type not supported, only float, BF16, and "
|
||||
"Int16 are supported");
|
||||
}
|
||||
// microscaling
|
||||
// TODO: Change to more generic function to inform if we should use
|
||||
|
|
@ -1461,6 +1649,13 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
|
||||
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
|
||||
}
|
||||
if constexpr (IsInt16<Q_T>()) {
|
||||
ApplyQuantizationScale<kNumQueries>(
|
||||
df, q_scales.data(), q_group_idx, kNumQueriesPerGroup, x_0_p_0,
|
||||
x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1,
|
||||
x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0,
|
||||
x_7_p_1);
|
||||
}
|
||||
|
||||
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
|
||||
constexpr int kSecondHalfAmountOfQueries =
|
||||
|
|
@ -1485,15 +1680,29 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
x_7_p_0, x_7_p_1);
|
||||
}
|
||||
HWY_ALIGN float scales[kNumQueriesPerLoop];
|
||||
// HWY_UNROLL(kNumQueriesPerLoop)
|
||||
|
||||
for (size_t i = 0; i < kNumQueriesPerLoop; ++i) {
|
||||
scales[i] = 1.0f;
|
||||
}
|
||||
|
||||
if constexpr (IsInt16<Q_T>() && kUseMicroScaling) {
|
||||
if (q_group_idx == 0) { // update only when needed
|
||||
const BF16* microscaling_scales_v =
|
||||
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
|
||||
kTileSize + pos_in_tile;
|
||||
const PackedSpan<const BF16> scales_span =
|
||||
MakeConstSpan(microscaling_scales_v, 2 * hn::Lanes(df));
|
||||
VF v_scales_p0, v_scales_p1;
|
||||
Decompress2(df, scales_span, 0, v_scales_p0, v_scales_p1);
|
||||
max_v_scale = hn::ReduceMax(df, hn::Max(v_scales_p0, v_scales_p1));
|
||||
}
|
||||
}
|
||||
|
||||
FlashAttentionTileStepAndApplySoftCap<kNumQueries>(
|
||||
df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
|
||||
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
|
||||
kNumQueriesPerGroup);
|
||||
kNumQueriesPerGroup, q_scales_s_ptr, max_v_scale);
|
||||
if constexpr (kUseMicroScaling) {
|
||||
const BF16* microscaling_scales_v =
|
||||
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
|
||||
|
|
@ -1508,7 +1717,13 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
|
||||
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
|
||||
} else if constexpr (IsBF16<Q_T>()) {
|
||||
} else if constexpr (IsInt16<Q_T>()) {
|
||||
MulByConstAndAddTileUpTo8_BF16_Int16<kNumQueries>(
|
||||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
|
||||
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx],
|
||||
q_scales_s);
|
||||
} else {
|
||||
MulByConstAndAddTileUpTo8_BF16<kNumQueries>(
|
||||
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
|
||||
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
|
||||
|
|
@ -1558,7 +1773,7 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
|
|||
float* HWY_RESTRICT max_logits) {
|
||||
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
|
||||
return TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
|
||||
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
|
||||
});
|
||||
}
|
||||
|
|
@ -1572,11 +1787,31 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
|||
float* HWY_RESTRICT max_logits) {
|
||||
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
|
||||
return TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
|
||||
kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
|
||||
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
|
||||
});
|
||||
}
|
||||
|
||||
// Entry point for tiled flash attention with int16 queries: views every KV
// matrix as int8 and forwards all arguments, including the per-query scales
// `q_scales`, to TileFlashAttentionReturnExpSumsAndMaxLogits.
// Every element of `kvs` is expected to have type Type::kInt8; this is only
// checked through HWY_DASSERT.
// NOTE(review): HWY_DASSERT is presumably compiled out of release builds, in
// which case a non-int8 KV would be reinterpreted without any check - confirm.
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
|
||||
hwy::Span<const MatPtr> kvs, int q_count,
|
||||
const hwy::Span<const int16_t* HWY_RESTRICT> q_T_in_groups_up_to_4,
|
||||
const hwy::Span<const float> q_scales,
|
||||
hwy::Span<const size_t> start_pos_per_query,
|
||||
hwy::Span<const size_t> last_pos_per_query, float att_cap,
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
|
||||
float* HWY_RESTRICT max_logits) {
|
||||
for ([[maybe_unused]] auto&& mat : kvs) {
|
||||
HWY_DASSERT(mat.GetType() == Type::kInt8);
|
||||
}
|
||||
// `matptrs` owns the typed views; the span below only borrows them, so it
// must not outlive this function.
auto matptrs = MakeMatPtrVec<int8_t>(kvs);
|
||||
hwy::Span<const MatPtrT<int8_t>> matptrs_span(matptrs.data(), matptrs.size());
|
||||
|
||||
return TileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
matptrs_span, q_count, q_T_in_groups_up_to_4, q_scales,
|
||||
start_pos_per_query, last_pos_per_query, att_cap, att_out,
|
||||
exp_denominator_sums, max_logits);
|
||||
}
|
||||
|
||||
// Implements flash attention for a strip of tiles of size 1, 4 or 8 query
|
||||
// vectors by 2NF positions in K.
|
||||
// It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to
|
||||
|
|
|
|||
|
|
@ -65,6 +65,16 @@ namespace gcpp {
|
|||
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
|
||||
float* HWY_RESTRICT max_logits); \
|
||||
\
|
||||
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( \
|
||||
hwy::Span<const MatPtr> kvs, int q_count, \
|
||||
const hwy::Span<const int16_t* HWY_RESTRICT> q_T_in_groups_up_to_4, \
|
||||
hwy::Span<const float> q_scales, \
|
||||
hwy::Span<const size_t> start_pos_per_query, \
|
||||
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
|
||||
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
|
||||
float* HWY_RESTRICT max_logits); \
|
||||
\
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@
|
|||
#include "gemma/attention.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/flash_attention.h"
|
||||
#include "gemma/tiled_attention.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -416,25 +417,22 @@ void TestTiledFlashAttentionBF16() {
|
|||
ctx.allocator, MatPadding::kPacked);
|
||||
PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim);
|
||||
|
||||
std::vector<BF16> q_float(num_queries_per_timestep * qkv_dim);
|
||||
std::vector<BF16> q_float2(num_queries_per_timestep * qkv_dim);
|
||||
// fill in qs with predictable, synthetic data
|
||||
for (size_t i = 0; i < num_queries_per_timestep; ++i) {
|
||||
for (size_t j = 0; j < qkv_dim; j += 2) {
|
||||
q_float[j * num_queries_per_timestep + i * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 1));
|
||||
q_float[j * num_queries_per_timestep + i * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 2));
|
||||
|
||||
q_float2[j * num_queries_per_timestep + i * 2] =
|
||||
hwy::ConvertScalarTo<BF16>(
|
||||
0.01f * (i + num_queries_per_timestep + 1) / (j + 1));
|
||||
q_float2[j * num_queries_per_timestep + i * 2 + 1] =
|
||||
hwy::ConvertScalarTo<BF16>(
|
||||
0.01f * (i + num_queries_per_timestep + 1) / (j + 2));
|
||||
std::vector<float> q_all(num_queries * qkv_dim);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
for (size_t j = 0; j < qkv_dim; ++j) {
|
||||
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
|
||||
}
|
||||
}
|
||||
const BF16* q_T[2] = {q_float.data(), q_float2.data()};
|
||||
std::vector<float*> q_ptrs(num_queries);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
q_ptrs[i] = q_all.data() + i * qkv_dim;
|
||||
}
|
||||
auto [transposed_queries, transposed_queries_ptrs, _] =
|
||||
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
|
||||
qkv_dim, /*group_size=*/4);
|
||||
hwy::Span<const BF16*> q_T(
|
||||
const_cast<const BF16**>(transposed_queries_ptrs.data()),
|
||||
transposed_queries_ptrs.size());
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
|
@ -465,8 +463,7 @@ void TestTiledFlashAttentionBF16() {
|
|||
}
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
||||
kvs, num_queries, hwy::Span<const BF16*>(q_T, 2),
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
|
|
@ -482,6 +479,10 @@ void TestTiledFlashAttentionBF16() {
|
|||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (size_t j = 0; j < qkv_dim; ++j) {
|
||||
if (j == 0) {
|
||||
std::cerr << "att_out[0][" << j << "]=" << att_out.Row(i)[j]
|
||||
<< " gold=" << att_out_gold[i * qkv_dim + j] << "\n";
|
||||
}
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f);
|
||||
}
|
||||
}
|
||||
|
|
@ -577,6 +578,186 @@ void TestTiledFlashAttentionInt8() {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
void TestTiledFlashAttentionInt8BF16() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
|
||||
// tiles size to test the padding logic.
|
||||
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
|
||||
float att_cap = 10.0f;
|
||||
int num_queries = 8;
|
||||
int num_queries_per_timestep = 4;
|
||||
int num_tokens = num_queries / num_queries_per_timestep;
|
||||
int kv_seq_end =
|
||||
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
|
||||
int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
|
||||
int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
|
||||
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
|
||||
|
||||
MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
||||
// fill in kvs with predictable, synthetic data matching BF16 paired layout
|
||||
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
|
||||
|
||||
std::vector<float> q_all(num_queries * qkv_dim);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
|
||||
}
|
||||
}
|
||||
std::vector<float*> q_ptrs(num_queries);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
q_ptrs[i] = q_all.data() + i * qkv_dim;
|
||||
}
|
||||
auto [transposed_queries, transposed_queries_ptrs, _] =
|
||||
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
|
||||
qkv_dim, /*group_size=*/4);
|
||||
hwy::Span<const BF16*> q_T(
|
||||
const_cast<const BF16**>(transposed_queries_ptrs.data()),
|
||||
transposed_queries_ptrs.size());
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
|
||||
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
|
||||
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
|
||||
std::vector<float> max_logits(num_queries_rounded_to_laness);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(i),
|
||||
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
|
||||
exp_denominator_sums[i] = 0.0f;
|
||||
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
|
||||
}
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
|
||||
start_pos_per_query.reserve(num_queries);
|
||||
last_pos_per_query.reserve(num_queries);
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
ssize_t query_last_pos = kv_seq_end + token_idx;
|
||||
ssize_t query_start_pos =
|
||||
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
|
||||
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
|
||||
++q_head_idx) {
|
||||
start_pos_per_query.push_back(query_start_pos);
|
||||
last_pos_per_query.push_back(query_last_pos);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
||||
kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
PrintMatPtr(att_out);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestTiledFlashAttentionInt8Int16() {
|
||||
int qkv_dim = 64;
|
||||
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
|
||||
// tiles size to test the padding logic.
|
||||
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
|
||||
float att_cap = 10.0f;
|
||||
int num_queries = 8;
|
||||
int num_queries_per_timestep = 4;
|
||||
int num_tokens = num_queries / num_queries_per_timestep;
|
||||
int kv_seq_end =
|
||||
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
|
||||
ThreadingArgs threading_args;
|
||||
ThreadingContext ctx(threading_args);
|
||||
|
||||
int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
|
||||
int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
|
||||
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
|
||||
|
||||
MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
|
||||
// fill in kvs with predictable, synthetic data matching BF16 paired layout
|
||||
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
|
||||
|
||||
std::vector<float> q_all(num_queries * qkv_dim);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
|
||||
}
|
||||
}
|
||||
std::vector<float*> q_ptrs(num_queries);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
q_ptrs[i] = q_all.data() + i * qkv_dim;
|
||||
}
|
||||
auto [transposed_queries, transposed_queries_ptrs, q_scales] =
|
||||
TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(
|
||||
hwy::Span<float*>(q_ptrs), qkv_dim, /*group_size=*/4);
|
||||
hwy::Span<const int16_t*> q_T(
|
||||
const_cast<const int16_t**>(transposed_queries_ptrs.data()),
|
||||
transposed_queries_ptrs.size());
|
||||
|
||||
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
|
||||
ctx.allocator, MatPadding::kPacked);
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
|
||||
size_t num_queries_rounded_to_laness = hwy::RoundUpTo(num_queries, lanes);
|
||||
std::vector<float> exp_denominator_sums(num_queries_rounded_to_laness);
|
||||
std::vector<float> max_logits(num_queries_rounded_to_laness);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
hwy::ZeroBytes(att_out.Row(i),
|
||||
qkv_dim * sizeof(decltype(att_out.Row(i)[0])));
|
||||
exp_denominator_sums[i] = 0.0f;
|
||||
max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
|
||||
}
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
|
||||
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
|
||||
start_pos_per_query.reserve(num_queries);
|
||||
last_pos_per_query.reserve(num_queries);
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
ssize_t query_last_pos = kv_seq_end + token_idx;
|
||||
ssize_t query_start_pos =
|
||||
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
|
||||
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
|
||||
++q_head_idx) {
|
||||
start_pos_per_query.push_back(query_start_pos);
|
||||
last_pos_per_query.push_back(query_last_pos);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::Span<const MatPtr> kvs(&kv, 1);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
|
||||
kvs, num_queries, q_T, q_scales,
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
|
||||
exp_denominator_sums.data(), max_logits.data());
|
||||
|
||||
PrintMatPtr(att_out);
|
||||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
@ -590,6 +771,8 @@ HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
|
|||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8BF16);
|
||||
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8Int16);
|
||||
HWY_AFTER_TEST();
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -82,7 +82,8 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
|
|||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
|
||||
if (activations.attention_impl == AttentionImpl::kFlashTransposedQs ||
|
||||
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16 ||
|
||||
activations.attention_impl == AttentionImpl::kFlashTransposedQsInt16) {
|
||||
TiledAttention(
|
||||
activations.attention_impl, num_tokens, layer_idx, layer,
|
||||
activations.attention, qbatch, env,
|
||||
|
|
|
|||
|
|
@ -76,24 +76,37 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
|
|||
const RuntimeConfig& runtime_config,
|
||||
const Allocator& allocator)
|
||||
: allocator_(allocator) {
|
||||
// clang-format off
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
|
||||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsInt16 ||
|
||||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| ((runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs
|
||||
) &&
|
||||
hwy::IsSame<KV_t, BF16>())) {
|
||||
) {
|
||||
// clang-format on
|
||||
const size_t num_tiles =
|
||||
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
|
||||
tiled_seq_len = num_tiles * kTileSize;
|
||||
Type kv_cache_type;
|
||||
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| hwy::IsSame<KV_t, BF16>()) {
|
||||
) {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
|
||||
} else if (runtime_config.attention_impl ==
|
||||
AttentionImpl::kFlashTransposedQsInt16) {
|
||||
if (runtime_config.kv_cache_type.has_value() &&
|
||||
runtime_config.kv_cache_type.value() != Type::kInt8) {
|
||||
HWY_WARN(
|
||||
"You are have set kv_cache_type to %s, but you are using "
|
||||
"FlashTransposedQsInt16 attention implementation which only "
|
||||
"supports Int8. kv_cache_type will be set to Int8.",
|
||||
runtime_config.kv_cache_type.value());
|
||||
}
|
||||
kv_cache_type = Type::kInt8;
|
||||
} else {
|
||||
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
|
||||
}
|
||||
|
||||
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
|
||||
if (kv_cache_type == Type::kInt8) {
|
||||
// microscaling
|
||||
tile_length += 2 * sizeof(BF16) * kTileSize;
|
||||
}
|
||||
auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size,
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -127,6 +129,10 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
// Apply positional encodings and store K/V in tiled format.
|
||||
hwy::Divisor div_kv_heads(kv_heads);
|
||||
|
||||
bool is_transposed_qs =
|
||||
attention_impl == AttentionImpl::kFlashTransposedQsBF16
|
||||
|| attention_impl == AttentionImpl::kFlashTransposedQsInt16;
|
||||
|
||||
hn::ScalableTag<float> df;
|
||||
static hwy::Divisor tile_size_divisor(KVCache::kTileSize);
|
||||
ParallelFor(
|
||||
|
|
@ -249,7 +255,7 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
v_cache_values = v_buf;
|
||||
}
|
||||
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
if (is_transposed_qs) {
|
||||
const int in_tile_idx_mod_2 = in_tile_idx % 2;
|
||||
for (int dim = 0; dim < qkv_dim; dim += 2) {
|
||||
const int dim_mod_2 = dim % 2;
|
||||
|
|
@ -280,7 +286,7 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
}
|
||||
Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls,
|
||||
tile_packed_span, 0);
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
if (is_transposed_qs) {
|
||||
Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls,
|
||||
tile_packed_span, qkv_dim * KVCache::kTileSize);
|
||||
}
|
||||
|
|
@ -289,23 +295,6 @@ static HWY_INLINE void ComputeQKVTransposedTile(
|
|||
});
|
||||
}
|
||||
|
||||
// TODO: optimize with gathers
|
||||
// This format might change in the future, when kernel will be updated to
|
||||
// support more than 8 queries.
|
||||
// Input (num_queries, qkv_dim)
|
||||
// Output (qkv_dim, num_queries)
|
||||
void TransposeQ(const MatPtrT<float>& queries,
|
||||
hwy::Span<float> transposed_queries_span) {
|
||||
const size_t qkv_dim = queries.Cols();
|
||||
const size_t num_queries = queries.Rows();
|
||||
HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim);
|
||||
for (size_t i = 0; i < qkv_dim; i++) {
|
||||
for (size_t j = 0; j < num_queries; ++j) {
|
||||
transposed_queries_span[i * num_queries + j] = queries.Row(j)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transposes queries
|
||||
// Input: vector of pointers to subsequent queries. (allows for arbitrary
|
||||
// strides)
|
||||
|
|
@ -375,6 +364,144 @@ std::pair<AlignedFloatVector, std::vector<float*>> TransposeQueriesToGroupsOf4(
|
|||
std::move(transposed_queries_ptrs));
|
||||
}
|
||||
|
||||
template <typename OutT>
|
||||
static HWY_INLINE void TransposeStridedQueriesBF16orInt16(
|
||||
hwy::Span<const float*> queries, int qkv_dim,
|
||||
hwy::Span<OutT> transposed_queries, hwy::Span<float> q_scales) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
const DF df;
|
||||
using VF = hn::Vec<DF>;
|
||||
// doubles to avoid moving between int/float domains when gathering
|
||||
using DF64 = hn::ScalableTag<double>;
|
||||
const DF64 dd64;
|
||||
using DI64 = hn::ScalableTag<int64_t>;
|
||||
const DI64 di64;
|
||||
using VI64 = hn::Vec<DI64>;
|
||||
auto d_out = hn::Rebind<OutT, decltype(df)>();
|
||||
const size_t lanes = hn::Lanes(df);
|
||||
const size_t half_lanes = lanes / 2;
|
||||
const size_t num_queries = queries.size();
|
||||
const size_t num_numbers_to_gather = num_queries * 2;
|
||||
const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, half_lanes);
|
||||
const size_t num_scales_rounded_up =
|
||||
hwy::RoundUpTo(num_numbers_to_gather, lanes);
|
||||
|
||||
// We store scales twice so we will be able to just load them without a need
|
||||
// to duplicate for multiplication
|
||||
AlignedFloatVector inverted_q_scales_doubled(num_scales_rounded_up);
|
||||
|
||||
if constexpr (IsInt16<OutT>()) {
|
||||
// compute microscales
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
float max_abs = AbsMaxOfSpan(hwy::Span<const float>(queries[i], qkv_dim));
|
||||
float scale = max_abs == 0.0f ? 1.0f : 32767.0f / max_abs;
|
||||
inverted_q_scales_doubled[2 * i] = scale;
|
||||
inverted_q_scales_doubled[2 * i + 1] = scale;
|
||||
q_scales[i] = 1.0f / scale;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t, hwy::AlignedAllocator<int64_t>> query_offsets(
|
||||
num_queries_rounded_up);
|
||||
for (size_t i = 0; i < num_queries; ++i) {
|
||||
query_offsets[i] = (queries[i] - queries[0]) / 2;
|
||||
}
|
||||
for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
|
||||
// last offset is the same so gather doesn't read out of bounds
|
||||
query_offsets[i] = query_offsets[num_queries > 0 ? num_queries - 1 : 0];
|
||||
}
|
||||
|
||||
const double* queries_0_double = HWY_RCAST_ALIGNED(const double*, queries[0]);
|
||||
|
||||
// Lambda to handle the scaling and demotion for Int16 types.
|
||||
auto process_values = [&]() HWY_ATTR {
|
||||
if constexpr (IsInt16<OutT>()) {
|
||||
return [&](VF& x, size_t j) HWY_ATTR {
|
||||
VF scales = hn::Load(df, inverted_q_scales_doubled.data() + j * 2);
|
||||
x = hn::Mul(x, scales);
|
||||
return hn::DemoteTo(d_out, hn::NearestInt(x));
|
||||
};
|
||||
} else {
|
||||
return [&](VF& x, size_t j) HWY_ATTR { return hn::DemoteTo(d_out, x); };
|
||||
}
|
||||
}();
|
||||
|
||||
for (size_t i = 0; i < qkv_dim; i += 2) {
|
||||
size_t j = 0;
|
||||
if (num_queries >= half_lanes) {
|
||||
for (; j <= num_queries - half_lanes; j += half_lanes) {
|
||||
const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
|
||||
auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
|
||||
VF x = hn::BitCast(df, x64);
|
||||
if constexpr (IsInt16<OutT>()) {
|
||||
auto demoted = process_values(x, j);
|
||||
hn::Store(demoted, d_out,
|
||||
transposed_queries.data() + i * num_queries + j * 2);
|
||||
} else if constexpr (IsBF16<OutT>()) {
|
||||
auto demoted = hn::DemoteTo(d_out, x);
|
||||
hn::Store(demoted, d_out,
|
||||
transposed_queries.data() + i * num_queries + j * 2);
|
||||
} else {
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (j < num_queries) {
|
||||
const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
|
||||
auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
|
||||
VF x = hn::BitCast(df, x64);
|
||||
if constexpr (IsInt16<OutT>()) {
|
||||
auto demoted = process_values(x, j);
|
||||
hn::StoreN(demoted, d_out,
|
||||
transposed_queries.data() + i * num_queries + j * 2,
|
||||
num_numbers_to_gather - j * 2);
|
||||
} else if constexpr (IsBF16<OutT>()) {
|
||||
auto demoted = hn::DemoteTo(d_out, x);
|
||||
hn::StoreN(demoted, d_out,
|
||||
transposed_queries.data() + i * num_queries + j * 2,
|
||||
num_numbers_to_gather - j * 2);
|
||||
} else {
|
||||
static_assert(false, "Unsupported type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Transposed queries data, vector of pointers to transposed queries, vector of
|
||||
// scales
|
||||
template <typename OutT>
|
||||
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, std::vector<OutT*>,
|
||||
AlignedFloatVector>
|
||||
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs,
|
||||
int qkv_dim, size_t group_size) {
|
||||
size_t num_queries = queries_ptrs.size();
|
||||
size_t num_groups = hwy::DivCeil(num_queries, group_size);
|
||||
std::vector<OutT, hwy::AlignedAllocator<OutT>> transposed_queries(
|
||||
num_groups * group_size * qkv_dim);
|
||||
std::vector<OutT*> transposed_queries_ptrs;
|
||||
AlignedFloatVector q_scales(num_groups * 4);
|
||||
for (size_t group_idx = 0; group_idx < num_groups; ++group_idx) {
|
||||
size_t current_group_size =
|
||||
std::min(group_size, num_queries - group_idx * group_size);
|
||||
transposed_queries_ptrs.push_back(transposed_queries.data() +
|
||||
group_idx * qkv_dim * group_size);
|
||||
TransposeStridedQueriesBF16orInt16(
|
||||
hwy::Span<const float*>(
|
||||
const_cast<const float**>(queries_ptrs.data() +
|
||||
group_idx * group_size),
|
||||
current_group_size),
|
||||
qkv_dim,
|
||||
hwy::Span<OutT>(transposed_queries_ptrs.back(),
|
||||
qkv_dim * current_group_size),
|
||||
hwy::Span<float>(q_scales.data() + group_idx * group_size,
|
||||
current_group_size));
|
||||
}
|
||||
return std::make_tuple(std::move(transposed_queries),
|
||||
std::move(transposed_queries_ptrs),
|
||||
std::move(q_scales));
|
||||
}
|
||||
|
||||
std::pair<AlignedBF16Vector, std::vector<BF16*>>
|
||||
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
|
||||
int qkv_dim, int num_queries) {
|
||||
|
|
@ -537,9 +664,6 @@ void LocalAttentionForAllHeadsTokensAndBatch(
|
|||
hwy::Span<float*> queries_ptrs_span(queries_ptrs.data(),
|
||||
queries_ptrs.size());
|
||||
|
||||
auto [transposed_queries, transposed_queries_ptrs] =
|
||||
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
|
||||
|
||||
MatStorageT<float>& att_out =
|
||||
activations.sub_task_att_out->at(task_idx);
|
||||
AlignedFloatVector& exp_denominator_sums =
|
||||
|
|
@ -604,23 +728,37 @@ void LocalAttentionForAllHeadsTokensAndBatch(
|
|||
last_pos_per_query.push_back(query_last_context_pos);
|
||||
}
|
||||
}
|
||||
|
||||
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
|
||||
// pack transposed queries into BF16
|
||||
hwy::Span<float*> queries_span(transposed_queries_ptrs.data(),
|
||||
transposed_queries_ptrs.size());
|
||||
auto [_, transposed_queries_ptrs_bf16] =
|
||||
TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim,
|
||||
num_queries);
|
||||
hwy::Span<const BF16*> queries_span_bf16(
|
||||
const_cast<const BF16**>(transposed_queries_ptrs_bf16.data()),
|
||||
transposed_queries_ptrs_bf16.size());
|
||||
auto [transposed_queries, transposed_queries_ptrs, _] =
|
||||
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(
|
||||
queries_ptrs_span, qkv_dim, /*group_size=*/4);
|
||||
hwy::Span<const BF16*> queries_span(
|
||||
const_cast<const BF16**>(transposed_queries_ptrs.data()),
|
||||
transposed_queries_ptrs.size());
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
|
||||
kv_ptrs, num_queries, queries_span_bf16,
|
||||
kv_ptrs, num_queries, queries_span,
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query),
|
||||
activations.config.att_cap, att_out, exp_denominator_sums.data(),
|
||||
max_logits.data());
|
||||
} else if (attention_impl == AttentionImpl::kFlashTransposedQsInt16) {
|
||||
auto [transposed_queries, transposed_queries_ptrs, q_scales] =
|
||||
TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(
|
||||
queries_ptrs_span, qkv_dim, /*group_size=*/4);
|
||||
hwy::Span<const int16_t*> queries_span(
|
||||
const_cast<const int16_t**>(transposed_queries_ptrs.data()),
|
||||
transposed_queries_ptrs.size());
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
|
||||
kv_ptrs, num_queries, queries_span, q_scales,
|
||||
hwy::Span<const size_t>(start_pos_per_query),
|
||||
hwy::Span<const size_t>(last_pos_per_query),
|
||||
activations.config.att_cap, att_out, exp_denominator_sums.data(),
|
||||
max_logits.data());
|
||||
} else {
|
||||
auto [transposed_queries, transposed_queries_ptrs] =
|
||||
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
|
||||
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
|
||||
kv_ptrs, num_queries,
|
||||
hwy::Span<const float*>(
|
||||
|
|
@ -712,8 +850,9 @@ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
|
|||
ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
|
||||
activations, qbatch, flags, env);
|
||||
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kF32) {
|
||||
ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
|
||||
activations, qbatch, flags, env);
|
||||
ComputeQKVTransposedTile<float>(num_tokens, layer_idx, layer,
|
||||
attention_impl, activations, qbatch, flags,
|
||||
env);
|
||||
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() ==
|
||||
Type::kInt8) {
|
||||
ComputeQKVTransposedTile<int8_t>(num_tokens, layer_idx, layer,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,12 @@ namespace gcpp {
|
|||
const size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
AttentionActivationsPtrs& activations, QBatch& qbatch, \
|
||||
ThreadingContext& ctx); \
|
||||
\
|
||||
template <typename OutT> \
|
||||
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
|
||||
std::vector<OutT*>, AlignedFloatVector> \
|
||||
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
|
||||
int qkv_dim, size_t group_size); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
186
ops/ops-inl.h
186
ops/ops-inl.h
|
|
@ -1078,6 +1078,148 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8(
|
|||
HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
// Specialized version for BF16 models that uses int16 quantization for V.
|
||||
template <int32_t N, class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16_Int16(
|
||||
DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1,
|
||||
const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1,
|
||||
const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1,
|
||||
const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1,
|
||||
const VF& c7_p0, const VF& c7_p1, const int8_t* HWY_RESTRICT v_tile,
|
||||
MatPtrT<float>& out, const float* HWY_RESTRICT q_scales_s) {
|
||||
static_assert(N <= 8);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
const size_t qkv_dim = out.Cols();
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
|
||||
|
||||
using DI16 = hn::Repartition<int16_t, DF>;
|
||||
const DI16 di16;
|
||||
const auto di16_half = hn::Half<DI16>();
|
||||
using DI32 = hn::Repartition<int32_t, DF>;
|
||||
const DI32 di32;
|
||||
using VI16 = hn::Vec<DI16>;
|
||||
using VI32 = hn::Vec<DI32>;
|
||||
using DI8 = hn::Repartition<int8_t, DF>;
|
||||
const hn::Half<DI8> di8_half;
|
||||
HWY_LANES_CONSTEXPR size_t kInt16Lanes = hn::Lanes(di16);
|
||||
|
||||
HWY_ALIGN int16_t cs_i16[N * kMaxLanes * 2];
|
||||
|
||||
auto quantize_s_and_store = [&](int j, const VF& p0, const VF& p1) HWY_ATTR {
|
||||
auto i0 =
|
||||
hn::OrderedDemote2To(di16, hn::NearestInt(p0), hn::NearestInt(p1));
|
||||
hn::Store(i0, di16, cs_i16 + j * kMaxLanes * 2);
|
||||
};
|
||||
|
||||
quantize_s_and_store(0, c0_p0, c0_p1);
|
||||
if constexpr (N >= 2) quantize_s_and_store(1, c1_p0, c1_p1);
|
||||
if constexpr (N >= 3) quantize_s_and_store(2, c2_p0, c2_p1);
|
||||
if constexpr (N >= 4) quantize_s_and_store(3, c3_p0, c3_p1);
|
||||
if constexpr (N >= 5) quantize_s_and_store(4, c4_p0, c4_p1);
|
||||
if constexpr (N >= 6) quantize_s_and_store(5, c5_p0, c5_p1);
|
||||
if constexpr (N >= 7) quantize_s_and_store(6, c6_p0, c6_p1);
|
||||
if constexpr (N >= 8) quantize_s_and_store(7, c7_p0, c7_p1);
|
||||
|
||||
size_t i = 0;
|
||||
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
|
||||
while (i + 2 * NF <= qkv_dim) {
|
||||
VI32 acc0_0 = hn::Zero(di32), acc0_1 = hn::Zero(di32);
|
||||
VI32 acc1_0 = hn::Zero(di32), acc1_1 = hn::Zero(di32);
|
||||
VI32 acc2_0 = hn::Zero(di32), acc2_1 = hn::Zero(di32);
|
||||
VI32 acc3_0 = hn::Zero(di32), acc3_1 = hn::Zero(di32);
|
||||
VI32 acc4_0 = hn::Zero(di32), acc4_1 = hn::Zero(di32);
|
||||
VI32 acc5_0 = hn::Zero(di32), acc5_1 = hn::Zero(di32);
|
||||
VI32 acc6_0 = hn::Zero(di32), acc6_1 = hn::Zero(di32);
|
||||
VI32 acc7_0 = hn::Zero(di32), acc7_1 = hn::Zero(di32);
|
||||
|
||||
VI32 acc0_o_0 = hn::Zero(di32), acc0_o_1 = hn::Zero(di32);
|
||||
VI32 acc1_o_0 = hn::Zero(di32), acc1_o_1 = hn::Zero(di32);
|
||||
VI32 acc2_o_0 = hn::Zero(di32), acc2_o_1 = hn::Zero(di32);
|
||||
VI32 acc3_o_0 = hn::Zero(di32), acc3_o_1 = hn::Zero(di32);
|
||||
VI32 acc4_o_0 = hn::Zero(di32), acc4_o_1 = hn::Zero(di32);
|
||||
VI32 acc5_o_0 = hn::Zero(di32), acc5_o_1 = hn::Zero(di32);
|
||||
VI32 acc6_o_0 = hn::Zero(di32), acc6_o_1 = hn::Zero(di32);
|
||||
VI32 acc7_o_0 = hn::Zero(di32), acc7_o_1 = hn::Zero(di32);
|
||||
|
||||
for (int lane = 0; lane < NF; ++lane) {
|
||||
VI16 vi_first8, vi_next8;
|
||||
|
||||
const int8_t* v_ptr = v_tile + 2 * qkv_dim * lane + i * 2;
|
||||
|
||||
auto v8_t0 = hn::LoadU(di8_half, v_ptr);
|
||||
auto v16_t0 = hn::PromoteTo(di16, v8_t0);
|
||||
|
||||
auto v8_t1 = hn::LoadU(di8_half, v_ptr + kInt16Lanes);
|
||||
auto v16_t1 = hn::PromoteTo(di16, v8_t1);
|
||||
|
||||
vi_first8 = v16_t0;
|
||||
vi_next8 = v16_t1;
|
||||
|
||||
auto mul_acc = [&](int j, VI32& a0, VI32& a_o0, VI32& a1,
|
||||
VI32& a_o1) HWY_ATTR {
|
||||
int16_t s0 = cs_i16[2 * lane + j * kMaxLanes * 2];
|
||||
int16_t s1 = cs_i16[2 * lane + 1 + j * kMaxLanes * 2];
|
||||
|
||||
int32_t s01;
|
||||
hwy::CopySameSize(&s0, reinterpret_cast<int16_t*>(&s01));
|
||||
hwy::CopySameSize(&s1, reinterpret_cast<int16_t*>(&s01) + 1);
|
||||
VI16 sj = hn::BitCast(di16, hn::Set(di32, s01));
|
||||
|
||||
a0 = hn::ReorderWidenMulAccumulate(di32, vi_first8, sj, a0, a_o0);
|
||||
a1 = hn::ReorderWidenMulAccumulate(di32, vi_next8, sj, a1, a_o1);
|
||||
};
|
||||
|
||||
mul_acc(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1);
|
||||
if constexpr (N >= 2) mul_acc(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1);
|
||||
if constexpr (N >= 3) mul_acc(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1);
|
||||
if constexpr (N >= 4) mul_acc(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1);
|
||||
if constexpr (N >= 5) mul_acc(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1);
|
||||
if constexpr (N >= 6) mul_acc(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1);
|
||||
if constexpr (N >= 7) mul_acc(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1);
|
||||
if constexpr (N >= 8) mul_acc(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1);
|
||||
}
|
||||
|
||||
auto convert_and_add = [&](int j, VI32& a0, VI32& a_o0, VI32& a1,
|
||||
VI32& a_o1) HWY_ATTR {
|
||||
VF f0 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a0, a_o0));
|
||||
VF f1 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a1, a_o1));
|
||||
|
||||
VF o0 = hn::Load(df, out.Row(j) + i);
|
||||
VF o1 = hn::Load(df, out.Row(j) + i + NF);
|
||||
|
||||
VF scale_old = hn::Set(df, scales[j]);
|
||||
o0 = hn::Mul(o0, scale_old);
|
||||
o1 = hn::Mul(o1, scale_old);
|
||||
|
||||
VF scale_new = hn::Set(df, q_scales_s[j]);
|
||||
o0 = hn::MulAdd(f0, scale_new, o0);
|
||||
o1 = hn::MulAdd(f1, scale_new, o1);
|
||||
|
||||
hn::Store(o0, df, out.Row(j) + i);
|
||||
hn::Store(o1, df, out.Row(j) + i + NF);
|
||||
};
|
||||
|
||||
convert_and_add(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1);
|
||||
if constexpr (N >= 2)
|
||||
convert_and_add(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1);
|
||||
if constexpr (N >= 3)
|
||||
convert_and_add(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1);
|
||||
if constexpr (N >= 4)
|
||||
convert_and_add(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1);
|
||||
if constexpr (N >= 5)
|
||||
convert_and_add(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1);
|
||||
if constexpr (N >= 6)
|
||||
convert_and_add(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1);
|
||||
if constexpr (N >= 7)
|
||||
convert_and_add(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1);
|
||||
if constexpr (N >= 8)
|
||||
convert_and_add(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1);
|
||||
|
||||
i += 2 * NF;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
|
||||
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
|
||||
DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0,
|
||||
|
|
@ -1600,6 +1742,50 @@ MatStorageT<T> AvgPool4x4(MatStorageT<T>& input, const Allocator& allocator) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Reduces each of x and stores in following lanes of max (tested with float32)
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
F reducer) {
|
||||
const DF4 df4;
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
|
||||
HWY_ALIGN T x_transposed[4 * kMaxLanes];
|
||||
hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
|
||||
VF x01 =
|
||||
reducer(hn::Load(df, x_transposed), hn::Load(df, x_transposed + kLanes));
|
||||
VF x23 = reducer(hn::Load(df, x_transposed + 2 * kLanes),
|
||||
hn::Load(df, x_transposed + 3 * kLanes));
|
||||
VF x0123 = reducer(x01, x23);
|
||||
hn::Store(x0123, df, x_transposed);
|
||||
|
||||
VF4 result = hn::Load(df4, x_transposed);
|
||||
for (int i = 1; i < kLanes / 4; ++i) {
|
||||
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns vector with 8 lanes. Shouldn't be used on architectures
|
||||
// with less than 8 lanes per vector.
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
|
||||
VF x_5, VF x_6, VF x_7, F reducer) {
|
||||
auto res0123 = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
|
||||
auto res4567 = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
|
||||
|
||||
using DF4 = hn::CappedTag<T, 4>;
|
||||
const DF4 df4;
|
||||
const DF8 df8;
|
||||
HWY_ALIGN T buf[8];
|
||||
hn::Store(res0123, df4, buf);
|
||||
hn::Store(res4567, df4, buf + 4);
|
||||
return hn::Load(df8, buf);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
24
util/mat.h
24
util/mat.h
|
|
@ -481,6 +481,16 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<MatPtrT<T>> MakeMatPtrVec(hwy::Span<const MatPtr> base) {
|
||||
std::vector<MatPtrT<T>> matptrs;
|
||||
matptrs.reserve(base.size());
|
||||
for (auto&& mat : base) {
|
||||
matptrs.emplace_back(mat);
|
||||
}
|
||||
return matptrs;
|
||||
}
|
||||
|
||||
// Calls 'func' with a span of MatPtrT<T> for all elements in `base`.
|
||||
// T is dynamic type, read from base. It is assumed that all elements in `base`
|
||||
// have the same type.
|
||||
|
|
@ -491,25 +501,17 @@ decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
|||
for ([[maybe_unused]] auto&& mat : base) {
|
||||
HWY_DASSERT(mat.GetType() == type);
|
||||
}
|
||||
auto make_matptr_vec = [&base](auto element) {
|
||||
std::vector<MatPtrT<decltype(element)>> matptrs;
|
||||
matptrs.reserve(base.size());
|
||||
for (auto&& mat : base) {
|
||||
matptrs.emplace_back(mat);
|
||||
}
|
||||
return matptrs;
|
||||
};
|
||||
if (type == Type::kF32) {
|
||||
auto matptrs = make_matptr_vec(float{});
|
||||
auto matptrs = MakeMatPtrVec<float>(base);
|
||||
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
|
||||
matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else if (type == Type::kBF16) {
|
||||
auto matptrs = make_matptr_vec(BF16{});
|
||||
auto matptrs = MakeMatPtrVec<BF16>(base);
|
||||
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else if (type == Type::kInt8) {
|
||||
auto matptrs = make_matptr_vec(int8_t{});
|
||||
auto matptrs = MakeMatPtrVec<int8_t>(base);
|
||||
hwy::Span<const MatPtrT<int8_t>> matptrs_span(matptrs.data(),
|
||||
matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
|
|
|
|||
Loading…
Reference in New Issue