Improvements to inference using int8 compressed kv's

Multiplication is done using int16*int16 multiplication instructions, avoiding expensive conversion to f32/bf16.
~2x speedup on Zen 3.

PiperOrigin-RevId: 888690192
This commit is contained in:
Krzysztof Rymski 2026-03-24 08:50:48 -07:00 committed by Copybara-Service
parent 259b757aef
commit f56d18dd68
12 changed files with 912 additions and 130 deletions

View File

@ -198,6 +198,11 @@ constexpr bool IsInt8() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, int8_t>();
}
// True iff `Packed` (with cv-qualifiers and references removed) is int16_t.
// Mirrors the neighboring IsInt8()/IsBF16() traits.
template <typename Packed>
constexpr bool IsInt16() {
  return hwy::IsSame<hwy::RemoveCvRef<Packed>, int16_t>();
}
template <typename Packed>
constexpr bool IsBF16() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, BF16>();

View File

@ -718,6 +718,7 @@ constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
{"flash", AttentionImpl::kFlash},
{"flash_transposed_qs", AttentionImpl::kFlashTransposedQs},
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
{"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16},
};
std::string GetAttentionImplName(AttentionImpl impl) {

View File

@ -99,6 +99,7 @@ enum class AttentionImpl {
kFlash, // Flash Attention (default)
kFlashTransposedQs,
kFlashTransposedQsBF16,
kFlashTransposedQsInt16,
kSentinel,
};

View File

@ -558,57 +558,14 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
return scale;
}
// Reduces each of x_0..x_3 with `reducer` and returns the four results in
// consecutive lanes of a 4-lane vector (tested with float32).
template <class DF, typename T = hn::TFromD<DF>,
          class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
          class VF = hn::Vec<DF>, typename F>
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
                              F reducer) {
  const DF4 df4;
  constexpr size_t kMaxLanes = hn::MaxLanes(df);
  HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
  // Interleave the four inputs so that lane i of each input becomes a
  // contiguous group of 4 in memory.
  HWY_ALIGN T x_transposed[4 * kMaxLanes];
  hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
  VF x01 =
      reducer(hn::Load(df, x_transposed), hn::Load(df, x_transposed + kLanes));
  VF x23 = reducer(hn::Load(df, x_transposed + 2 * kLanes),
                   hn::Load(df, x_transposed + 3 * kLanes));
  VF x0123 = reducer(x01, x23);
  hn::Store(x0123, df, x_transposed);
  // Fold the remaining groups of 4 lanes into the first group.
  VF4 result = hn::Load(df4, x_transposed);
  for (int i = 1; i < kLanes / 4; ++i) {
    result = reducer(result, hn::Load(df4, x_transposed + i * 4));
  }
  return result;
}
// Reduces each of x_0..x_7 with `reducer` and returns the eight results in
// consecutive lanes of an 8-lane vector. Shouldn't be used on architectures
// with fewer than 8 lanes per vector.
template <class DF, typename T = hn::TFromD<DF>,
          class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
          class VF = hn::Vec<DF>, typename F>
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
                              VF x_5, VF x_6, VF x_7, F reducer) {
  auto res0123 = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
  auto res4567 = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
  using DF4 = hn::CappedTag<T, 4>;
  const DF4 df4;
  const DF8 df8;
  // Concatenate the two 4-lane results via memory into one 8-lane vector.
  HWY_ALIGN T buf[8];
  hn::Store(res0123, df4, buf);
  hn::Store(res4567, df4, buf + 4);
  return hn::Load(df8, buf);
}
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
float* HWY_RESTRICT scales) {
float* HWY_RESTRICT scales, float* HWY_RESTRICT q_scales_s = nullptr,
float max_v_scale = 1.0f) {
using DF4 = hn::CappedTag<float, 4>;
const DF4 df4;
using VF4 = hn::Vec<DF4>;
@ -636,6 +593,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
}
VF4 local_max = new_max;
VF4 old_max_vf = hn::Set(df4, kNegInf);
old_max_vf = hn::LoadU(df4, old_max);
new_max = hn::Max(new_max, old_max_vf);
@ -679,8 +637,28 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
const VF4 zero4 = hn::Zero(df4);
const VF4 one_over_d =
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
VF4 q_scale;
if (q_scales_s != nullptr) {
// max_s = exp(local_max - new_max) / old_d_vf
VF4 max_s = hn::Mul(one_over_d, hn::Exp(df4, hn::Sub(local_max, new_max)));
// Output the unquantized scale directly to array memory: because we cap
// values at 32767 / max_v_scale, the true scale grows proportionately.
hn::Store(hn::Mul(max_s, hn::Set(df4, max_v_scale / 32767.0f)), df4,
q_scales_s);
// multiplier for x = 32767 * exp(new_max - local_max) / max_v_scale
auto max_s_gt_0 = hn::Gt(max_s, zero4);
float inv_max_v_scale = 1.0f / std::max(max_v_scale, 1e-10f);
VF4 mult = hn::Mul(hn::Set(df4, 32767.0f * inv_max_v_scale),
hn::Exp(df4, hn::Sub(new_max, local_max)));
q_scale = hn::IfThenElse(max_s_gt_0, mult, zero4);
} else {
q_scale = one_over_d;
}
HWY_ALIGN float tmp_one_over_d[4];
hn::Store(one_over_d, df4, tmp_one_over_d);
hn::Store(q_scale, df4, tmp_one_over_d);
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
scale = hn::Mul(scale, one_over_d);
hn::BlendedStore(scale, changed_max, df4, scales);
@ -713,7 +691,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales) {
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales,
float* HWY_RESTRICT q_scales_s = nullptr, float max_v_scale = 1.0f) {
using DF8 = hn::CappedTag<float, 8>;
const DF8 df8;
using VF8 = hn::Vec<DF8>;
@ -755,6 +734,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
VF8 one_over_cap = hn::Set(df8, one_over_att_cap);
new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap)));
}
VF8 local_max = new_max;
VF8 old_max_vf = hn::Set(df8, kNegInf);
old_max_vf = hn::LoadU(df8, old_max);
new_max = hn::Max(new_max, old_max_vf);
@ -817,8 +797,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
const VF8 zero8 = hn::Zero(df8);
const VF8 one_over_d =
hn::MaskedDivOr(zero8, non_zero_mask, hn::Set(df8, 1.0f), old_d_vf);
VF8 q_scale;
if (q_scales_s != nullptr) {
VF8 max_s = hn::Mul(one_over_d, hn::Exp(df8, hn::Sub(local_max, new_max)));
hn::Store(hn::Mul(max_s, hn::Set(df8, max_v_scale / 32767.0f)), df8,
q_scales_s);
auto max_s_gt_0 = hn::Gt(max_s, zero8);
float inv_max_v_scale = 1.0f / std::max(max_v_scale, 1e-10f);
VF8 mult = hn::Mul(hn::Set(df8, 32767.0f * inv_max_v_scale),
hn::Exp(df8, hn::Sub(new_max, local_max)));
q_scale = hn::IfThenElse(max_s_gt_0, mult, zero8);
} else {
q_scale = one_over_d;
}
HWY_ALIGN float tmp_one_over_d[8];
hn::Store(one_over_d, df8, tmp_one_over_d);
hn::Store(q_scale, df8, tmp_one_over_d);
hn::BlendedStore(old_d_vf, changed_max, df8, old_d);
scale = hn::Mul(scale, one_over_d);
hn::BlendedStore(scale, changed_max, df8, scales);
@ -862,7 +856,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx,
size_t kNumQueriesPerGroup) {
size_t kNumQueriesPerGroup, float* HWY_RESTRICT q_scales_s = nullptr,
float max_v_scale = 1.0f) {
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
[[maybe_unused]] constexpr int kSecondHalfAmountOfQueries =
kNumQueries - kFirstHalfAmountOfQueries;
@ -870,25 +865,30 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
FlashAttentionTileStepAndApplySoftCap4<kFirstHalfAmountOfQueries>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
max_v_scale);
} else {
#if HWY_MAX_BYTES <= 16
FlashAttentionTileStepAndApplySoftCap4<4>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
max_v_scale);
FlashAttentionTileStepAndApplySoftCap4<kSecondHalfAmountOfQueries>(
df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
x_6_p1, x_7_p0, x_7_p1,
old_max + (q_group_idx + 1) * kNumQueriesPerGroup,
old_d + (q_group_idx + 1) * kNumQueriesPerGroup,
scales + kNumQueriesPerGroup);
scales + kNumQueriesPerGroup,
q_scales_s == nullptr ? nullptr : q_scales_s + kNumQueriesPerGroup,
max_v_scale);
#else
FlashAttentionTileStepAndApplySoftCap8<kNumQueries>(
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
old_d + (q_group_idx)*kNumQueriesPerGroup, scales, q_scales_s,
max_v_scale);
#endif
}
}
@ -998,6 +998,138 @@ static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidth(
}
}
// Computes Q.K dot products for up to 8 int16 query rows against one int8
// transposed K tile. For each K position in the tile, the float dot product
// is written to the corresponding lane of sum*_p0 (first vector of
// positions) / sum*_p1 (second vector). Accumulation happens in int32 via
// ReorderWidenMulAccumulate (int16*int16 -> int32), avoiding conversion to
// f32/bf16 inside the loop; conversion to float happens once at the end.
// `q` holds the first (up to 4) queries and `q2` queries 5..8; both are laid
// out so that each int32 read yields one query's pair of adjacent int16
// feature values (see q_int32_ptr indexing below).
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthInt16(
    DF df, const int16_t* HWY_RESTRICT q, const int16_t* HWY_RESTRICT q2,
    const int8_t* HWY_RESTRICT k_transposed_tile, size_t qkv_dim, VF& sum0_p0,
    VF& sum0_p1, VF& sum1_p0, VF& sum1_p1, VF& sum2_p0, VF& sum2_p1,
    VF& sum3_p0, VF& sum3_p1, VF& sum4_p0, VF& sum4_p1, VF& sum5_p0,
    VF& sum5_p1, VF& sum6_p0, VF& sum6_p1, VF& sum7_p0, VF& sum7_p1) {
  using DI16 = hn::ScalableTag<int16_t>;
  const DI16 di16;
  using VI16 = hn::Vec<DI16>;
  using DI32 = hn::Repartition<int32_t, DF>;
  const DI32 di32;
  using VI32 = hn::Vec<DI32>;
  HWY_DASSERT(hn::Lanes(di16) <= gcpp::KVCache::kTileSize);
  HWY_DASSERT(kNumQueries <= 8);
  HWY_DASSERT(gcpp::KVCache::kTileSize >= hn::Lanes(df) * 2);
  // Per-query int32 accumulators for each half of the tile positions.
  // ReorderWidenMulAccumulate produces separate "even" and "odd" partial
  // sums; they are recombined with RearrangeToOddPlusEven at the end.
  VI32 isum0_p0 = hn::Zero(di32);
  VI32 isum0_p1 = hn::Zero(di32);
  VI32 isum1_p0 = hn::Zero(di32), isum1_p1 = hn::Zero(di32);
  VI32 isum2_p0 = hn::Zero(di32), isum2_p1 = hn::Zero(di32);
  VI32 isum3_p0 = hn::Zero(di32), isum3_p1 = hn::Zero(di32);
  VI32 isum4_p0 = hn::Zero(di32), isum4_p1 = hn::Zero(di32);
  VI32 isum5_p0 = hn::Zero(di32), isum5_p1 = hn::Zero(di32);
  VI32 isum6_p0 = hn::Zero(di32), isum6_p1 = hn::Zero(di32);
  VI32 isum7_p0 = hn::Zero(di32), isum7_p1 = hn::Zero(di32);
  VI32 isum0_odd_p0 = hn::Zero(di32), isum0_odd_p1 = hn::Zero(di32);
  VI32 isum1_odd_p0 = hn::Zero(di32), isum1_odd_p1 = hn::Zero(di32);
  VI32 isum2_odd_p0 = hn::Zero(di32), isum2_odd_p1 = hn::Zero(di32);
  VI32 isum3_odd_p0 = hn::Zero(di32), isum3_odd_p1 = hn::Zero(di32);
  VI32 isum4_odd_p0 = hn::Zero(di32), isum4_odd_p1 = hn::Zero(di32);
  VI32 isum5_odd_p0 = hn::Zero(di32), isum5_odd_p1 = hn::Zero(di32);
  VI32 isum6_odd_p0 = hn::Zero(di32), isum6_odd_p1 = hn::Zero(di32);
  VI32 isum7_odd_p0 = hn::Zero(di32), isum7_odd_p1 = hn::Zero(di32);
  // Read queries as int32 so one load grabs a pair of adjacent int16 values.
  const int32_t* q_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q);
  const int32_t* q2_int32_ptr = HWY_RCAST_ALIGNED(const int32_t*, q2);
  constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
  constexpr int kSecondHalfAmountOfQueries =
      kNumQueries - kFirstHalfAmountOfQueries;
  const hn::Repartition<int8_t, DI16> di8;
  const hn::Half<decltype(di8)> di8_half;
  // Each iteration consumes two feature dims (the "double width" layout
  // stores K dims in adjacent pairs), over all tile positions.
  for (size_t i = 0; i < qkv_dim / 2; i++) {
    auto k_dim0 = hn::LoadU(
        di8_half, k_transposed_tile + (i * 2) * gcpp::KVCache::kTileSize);
    auto k_dim1 = hn::LoadU(di8_half, k_transposed_tile +
                                          (i * 2) * gcpp::KVCache::kTileSize +
                                          hn::Lanes(di8_half));
    // Widen int8 K values to int16 so they can feed the int16 multiplier.
    auto k_vec1 = hn::PromoteTo(di16, k_dim0);
    auto k_vec2 = hn::PromoteTo(di16, k_dim1);
    // Broadcasts one query's int16 pair across all lanes and accumulates
    // the widened products into the query's p0/p1 accumulators.
    auto accumulate = [&](int32_t q_val, VI32& sum_p0, VI32& sum_p1,
                          VI32& sum_odd_p0, VI32& sum_odd_p1) HWY_ATTR {
      VI16 q_vec = hn::BitCast(di16, hn::Set(di32, q_val));
      sum_p0 = hn::ReorderWidenMulAccumulate(di32, k_vec1, q_vec, sum_p0,
                                             sum_odd_p0);
      sum_p1 = hn::ReorderWidenMulAccumulate(di32, k_vec2, q_vec, sum_p1,
                                             sum_odd_p1);
    };
    accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries], isum0_p0, isum0_p1,
               isum0_odd_p0, isum0_odd_p1);
    if constexpr (kNumQueries >= 2) {
      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 1], isum1_p0,
                 isum1_p1, isum1_odd_p0, isum1_odd_p1);
    }
    if constexpr (kNumQueries >= 3) {
      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 2], isum2_p0,
                 isum2_p1, isum2_odd_p0, isum2_odd_p1);
    }
    if constexpr (kNumQueries >= 4) {
      accumulate(q_int32_ptr[i * kFirstHalfAmountOfQueries + 3], isum3_p0,
                 isum3_p1, isum3_odd_p0, isum3_odd_p1);
    }
    if constexpr (kNumQueries >= 5) {
      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 0], isum4_p0,
                 isum4_p1, isum4_odd_p0, isum4_odd_p1);
    }
    if constexpr (kNumQueries >= 6) {
      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 1], isum5_p0,
                 isum5_p1, isum5_odd_p0, isum5_odd_p1);
    }
    if constexpr (kNumQueries >= 7) {
      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 2], isum6_p0,
                 isum6_p1, isum6_odd_p0, isum6_odd_p1);
    }
    if constexpr (kNumQueries >= 8) {
      accumulate(q2_int32_ptr[i * kSecondHalfAmountOfQueries + 3], isum7_p0,
                 isum7_p1, isum7_odd_p0, isum7_odd_p1);
    }
  }
  // Merge even/odd partial sums into one int32 dot product per lane, then
  // convert to float for the caller's accumulators.
  auto convert_to_float = [&](const VI32& sum_p0, const VI32& sum_odd_p0,
                              const VI32& sum_p1, const VI32& sum_odd_p1,
                              VF& out_p0, VF& out_p1) HWY_ATTR {
    out_p0 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(sum_p0, sum_odd_p0));
    out_p1 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(sum_p1, sum_odd_p1));
  };
  convert_to_float(isum0_p0, isum0_odd_p0, isum0_p1, isum0_odd_p1, sum0_p0,
                   sum0_p1);
  if constexpr (kNumQueries >= 2) {
    convert_to_float(isum1_p0, isum1_odd_p0, isum1_p1, isum1_odd_p1, sum1_p0,
                     sum1_p1);
  }
  if constexpr (kNumQueries >= 3) {
    convert_to_float(isum2_p0, isum2_odd_p0, isum2_p1, isum2_odd_p1, sum2_p0,
                     sum2_p1);
  }
  if constexpr (kNumQueries >= 4) {
    convert_to_float(isum3_p0, isum3_odd_p0, isum3_p1, isum3_odd_p1, sum3_p0,
                     sum3_p1);
  }
  if constexpr (kNumQueries >= 5) {
    convert_to_float(isum4_p0, isum4_odd_p0, isum4_p1, isum4_odd_p1, sum4_p0,
                     sum4_p1);
  }
  if constexpr (kNumQueries >= 6) {
    convert_to_float(isum5_p0, isum5_odd_p0, isum5_p1, isum5_odd_p1, sum5_p0,
                     sum5_p1);
  }
  if constexpr (kNumQueries >= 7) {
    convert_to_float(isum6_p0, isum6_odd_p0, isum6_p1, isum6_odd_p1, sum6_p0,
                     sum6_p1);
  }
  if constexpr (kNumQueries >= 8) {
    convert_to_float(isum7_p0, isum7_odd_p0, isum7_p1, isum7_odd_p1, sum7_p0,
                     sum7_p1);
  }
}
template <int kNumQueries, class DF, class VF = hn::Vec<DF>, typename T>
static HWY_INLINE void QDotKTilexUpTo8TransposedKDoubleWidthBF16(
DF df, const BF16* HWY_RESTRICT q, const BF16* HWY_RESTRICT q2,
@ -1306,6 +1438,47 @@ static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
}
}
// Multiplies each query's pair of partial-result vectors by that query's
// dequantization scale from `q_scales`. Queries 0..3 read scales from group
// `q_group_idx`, queries 4..7 from group `q_group_idx + 1`, matching the
// groups-of-4 query layout used by the int16 path.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static HWY_INLINE void ApplyQuantizationScale(
    DF df, const float* HWY_RESTRICT q_scales, int q_group_idx,
    int kNumQueriesPerGroup, VF& x0_p0, VF& x0_p1, VF& x1_p0, VF& x1_p1,
    VF& x2_p0, VF& x2_p1, VF& x3_p0, VF& x3_p1, VF& x4_p0, VF& x4_p1, VF& x5_p0,
    VF& x5_p1, VF& x6_p0, VF& x6_p1, VF& x7_p0, VF& x7_p1) {
  // Broadcasts the scalar scale for (group, query) and scales both halves.
  auto apply_scale = [&](int group_offset, int query_offset, VF& x_p0,
                         VF& x_p1) HWY_ATTR {
    int scale_idx =
        (q_group_idx + group_offset) * kNumQueriesPerGroup + query_offset;
    VF s = hn::Set(df, q_scales[scale_idx]);
    x_p0 = hn::Mul(x_p0, s);
    x_p1 = hn::Mul(x_p1, s);
  };
  if constexpr (kNumQueries >= 1) {
    apply_scale(0, 0, x0_p0, x0_p1);
  }
  if constexpr (kNumQueries >= 2) {
    apply_scale(0, 1, x1_p0, x1_p1);
  }
  if constexpr (kNumQueries >= 3) {
    apply_scale(0, 2, x2_p0, x2_p1);
  }
  if constexpr (kNumQueries >= 4) {
    apply_scale(0, 3, x3_p0, x3_p1);
  }
  if constexpr (kNumQueries >= 5) {
    apply_scale(1, 0, x4_p0, x4_p1);
  }
  if constexpr (kNumQueries >= 6) {
    apply_scale(1, 1, x5_p0, x5_p1);
  }
  if constexpr (kNumQueries >= 7) {
    apply_scale(1, 2, x6_p0, x6_p1);
  }
  if constexpr (kNumQueries >= 8) {
    apply_scale(1, 3, x7_p0, x7_p1);
  }
}
// Performs tiled flash attention for arbitrary number of queries
// It depends on kv being tiled.
// Runs 2 loops one over tiles, and inner one over queries(up to 4 at a time).
@ -1317,6 +1490,7 @@ static HWY_INLINE void MultiplyByScale(DF df, const BF16* scales, VF& x0_p0,
// as it will be used to figure out when to switch to the next one.
// q_T_in_groups_up_to_4 - Span of float* All except last float*
// should have (qkv_dim, 4) Last one can have any size up to 4.
// q_scales - Span of float of shape (q_count,): per-query dequantization
// scales, used when queries are in int16 format
// start_pos_per_query - start position in kv to start attention from ()
// last_pos_per_query - last position in kv to attend to (exclusive)
// queries_per_timestep - how many queries begin/end on the same timestep
@ -1332,6 +1506,7 @@ template <typename KV_T, typename Q_T>
HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
const hwy::Span<const MatPtrT<KV_T>> kvs, int q_count,
const hwy::Span<const Q_T * HWY_RESTRICT> q_T_in_groups_up_to_4,
const hwy::Span<const float> q_scales,
hwy::Span<const size_t> start_pos_per_query,
hwy::Span<const size_t> last_pos_per_query, const float att_cap,
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
@ -1406,6 +1581,12 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
size_t current_kv_start_offset = 0;
size_t current_kv_idx = 0;
HWY_ALIGN float q_scales_s[8];
float* q_scales_s_ptr = nullptr;
if constexpr (IsInt16<Q_T>()) {
q_scales_s_ptr = q_scales_s;
}
float max_v_scale = 1.0f;
auto inner_loop = [&]<int kNumQueries>(int q_group_idx) HWY_ATTR {
int loop_idx = q_group_idx / (kNumQueriesPerLoop / kNumQueriesPerGroup);
if (position + step_size <= min_start_pos_per_group[loop_idx] ||
@ -1429,6 +1610,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
if (kNumQueries > 4) {
q2_group = q_T_in_groups_up_to_4[q_group_idx + 1];
}
if constexpr (IsF32<Q_T>()) {
const KV_T* k_transposed_tile = tile_base + pos_in_tile;
QDotKTilexUpTo8TransposedKDoubleWidth<kNumQueries>(
@ -1441,10 +1623,16 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
} else if constexpr (IsInt16<Q_T>()) {
const KV_T* k_transposed_tile = tile_base + pos_in_tile * 2;
QDotKTilexUpTo8TransposedKDoubleWidthInt16<kNumQueries>(
df, q_group, q2_group, k_transposed_tile, qkv_dim, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1, x_4_p_0,
x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
} else {
static_assert(
false,
"Query type type not supported, only float and BF16 are supported");
static_assert(false,
"Query type not supported, only float, BF16, and "
"Int16 are supported");
}
// microscaling
// TODO: Change to more generic function to inform if we should use
@ -1461,6 +1649,13 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1,
x_6_p_0, x_6_p_1, x_7_p_0, x_7_p_1);
}
if constexpr (IsInt16<Q_T>()) {
ApplyQuantizationScale<kNumQueries>(
df, q_scales.data(), q_group_idx, kNumQueriesPerGroup, x_0_p_0,
x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, x_3_p_1,
x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1, x_7_p_0,
x_7_p_1);
}
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
constexpr int kSecondHalfAmountOfQueries =
@ -1485,15 +1680,29 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
x_7_p_0, x_7_p_1);
}
HWY_ALIGN float scales[kNumQueriesPerLoop];
// HWY_UNROLL(kNumQueriesPerLoop)
for (size_t i = 0; i < kNumQueriesPerLoop; ++i) {
scales[i] = 1.0f;
}
if constexpr (IsInt16<Q_T>() && kUseMicroScaling) {
if (q_group_idx == 0) { // update only when needed
const BF16* microscaling_scales_v =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
kTileSize + pos_in_tile;
const PackedSpan<const BF16> scales_span =
MakeConstSpan(microscaling_scales_v, 2 * hn::Lanes(df));
VF v_scales_p0, v_scales_p1;
Decompress2(df, scales_span, 0, v_scales_p0, v_scales_p1);
max_v_scale = hn::ReduceMax(df, hn::Max(v_scales_p0, v_scales_p1));
}
}
FlashAttentionTileStepAndApplySoftCap<kNumQueries>(
df, 0.0f, 1.0f, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0, x_6_p_1,
x_7_p_0, x_7_p_1, max_logits, exp_denominator_sums, scales, q_group_idx,
kNumQueriesPerGroup);
kNumQueriesPerGroup, q_scales_s_ptr, max_v_scale);
if constexpr (kUseMicroScaling) {
const BF16* microscaling_scales_v =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
@ -1508,7 +1717,13 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx]);
} else if constexpr (IsBF16<Q_T>()) {
} else if constexpr (IsInt16<Q_T>()) {
MulByConstAndAddTileUpTo8_BF16_Int16<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
x_6_p_1, x_7_p_0, x_7_p_1, v_tile, att_out_per_query[loop_idx],
q_scales_s);
} else {
MulByConstAndAddTileUpTo8_BF16<kNumQueries>(
df, scales, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1,
x_3_p_0, x_3_p_1, x_4_p_0, x_4_p_1, x_5_p_0, x_5_p_1, x_6_p_0,
@ -1558,7 +1773,7 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
return TileFlashAttentionReturnExpSumsAndMaxLogits(
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
});
}
@ -1572,11 +1787,31 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
return TileFlashAttentionReturnExpSumsAndMaxLogits(
kv_t, q_count, q_T_in_groups_up_to_4, start_pos_per_query,
kv_t, q_count, q_T_in_groups_up_to_4, {}, start_pos_per_query,
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
});
}
// Int16-query entry point for tiled flash attention: verifies (debug-only)
// that every KV matrix holds int8 tiles, rebuilds typed int8 views and
// forwards everything to the shared kernel.
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
    hwy::Span<const MatPtr> kvs, int q_count,
    const hwy::Span<const int16_t* HWY_RESTRICT> q_T_in_groups_up_to_4,
    const hwy::Span<const float> q_scales,
    hwy::Span<const size_t> start_pos_per_query,
    hwy::Span<const size_t> last_pos_per_query, float att_cap,
    MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
    float* HWY_RESTRICT max_logits) {
  // Sanity check: the int16 path only supports int8-encoded KV caches.
  for (size_t kv_idx = 0; kv_idx < kvs.size(); ++kv_idx) {
    HWY_DASSERT(kvs[kv_idx].GetType() == Type::kInt8);
  }
  const auto typed_kvs = MakeMatPtrVec<int8_t>(kvs);
  const hwy::Span<const MatPtrT<int8_t>> typed_span(typed_kvs.data(),
                                                    typed_kvs.size());
  TileFlashAttentionReturnExpSumsAndMaxLogits(
      typed_span, q_count, q_T_in_groups_up_to_4, q_scales,
      start_pos_per_query, last_pos_per_query, att_cap, att_out,
      exp_denominator_sums, max_logits);
}
// Implements flash attention for a strip of tiles of size 1, 4 or 8 query
// vectors by 2NF positions in K.
// It iterates through tiles in K from `params.min_start_pos / 2NF * 2NF` up to

View File

@ -65,6 +65,16 @@ namespace gcpp {
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( \
hwy::Span<const MatPtr> kvs, int q_count, \
const hwy::Span<const int16_t* HWY_RESTRICT> q_T_in_groups_up_to_4, \
hwy::Span<const float> q_scales, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -56,6 +56,7 @@
#include "gemma/attention.h"
#include "gemma/configs.h"
#include "gemma/flash_attention.h"
#include "gemma/tiled_attention.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
@ -416,25 +417,22 @@ void TestTiledFlashAttentionBF16() {
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim);
std::vector<BF16> q_float(num_queries_per_timestep * qkv_dim);
std::vector<BF16> q_float2(num_queries_per_timestep * qkv_dim);
// fill in qs with predictable, synthetic data
for (size_t i = 0; i < num_queries_per_timestep; ++i) {
for (size_t j = 0; j < qkv_dim; j += 2) {
q_float[j * num_queries_per_timestep + i * 2] =
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 1));
q_float[j * num_queries_per_timestep + i * 2 + 1] =
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 2));
q_float2[j * num_queries_per_timestep + i * 2] =
hwy::ConvertScalarTo<BF16>(
0.01f * (i + num_queries_per_timestep + 1) / (j + 1));
q_float2[j * num_queries_per_timestep + i * 2 + 1] =
hwy::ConvertScalarTo<BF16>(
0.01f * (i + num_queries_per_timestep + 1) / (j + 2));
std::vector<float> q_all(num_queries * qkv_dim);
for (size_t i = 0; i < num_queries; ++i) {
for (size_t j = 0; j < qkv_dim; ++j) {
q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
}
}
const BF16* q_T[2] = {q_float.data(), q_float2.data()};
std::vector<float*> q_ptrs(num_queries);
for (int i = 0; i < num_queries; ++i) {
q_ptrs[i] = q_all.data() + i * qkv_dim;
}
auto [transposed_queries, transposed_queries_ptrs, _] =
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
qkv_dim, /*group_size=*/4);
hwy::Span<const BF16*> q_T(
const_cast<const BF16**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());
MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
ctx.allocator, MatPadding::kPacked);
@ -465,8 +463,7 @@ void TestTiledFlashAttentionBF16() {
}
hwy::Span<const MatPtr> kvs(&kv, 1);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kvs, num_queries, hwy::Span<const BF16*>(q_T, 2),
hwy::Span<const size_t>(start_pos_per_query),
kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
exp_denominator_sums.data(), max_logits.data());
@ -482,6 +479,10 @@ void TestTiledFlashAttentionBF16() {
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (size_t j = 0; j < qkv_dim; ++j) {
if (j == 0) {
std::cerr << "att_out[0][" << j << "]=" << att_out.Row(i)[j]
<< " gold=" << att_out_gold[i * qkv_dim + j] << "\n";
}
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f);
}
}
@ -577,6 +578,186 @@ void TestTiledFlashAttentionInt8() {
}
}
// Exercises the tiled flash attention path that reads int8-compressed KV
// tiles with BF16 queries, comparing attention outputs, exp-denominator sums
// and max logits against the precomputed float gold values.
void TestTiledFlashAttentionInt8BF16() {
  const int qkv_dim = 64;
  // Number of tokens we will attend to. Not divisible by the tile size so
  // that the padding logic is exercised.
  const int kv_seq_len = 60;
  const int padded_kv_seq_len =
      hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
  const float att_cap = 10.0f;
  const int num_queries = 8;
  const int num_queries_per_timestep = 4;
  const int num_tokens = num_queries / num_queries_per_timestep;
  const int kv_seq_end =
      kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  const int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
  // Each tile: int8 K and V planes plus two BF16 microscaling rows.
  const int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
                              2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
  MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
                         ctx.allocator, MatPadding::kPacked);
  // Fill in kvs with predictable, synthetic data matching BF16 paired layout.
  PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
  // Synthetic queries with a predictable per-(query, dim) pattern.
  std::vector<float> q_all(num_queries * qkv_dim);
  for (int i = 0; i < num_queries; ++i) {
    for (int j = 0; j < qkv_dim; ++j) {
      q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
    }
  }
  std::vector<float*> q_ptrs(num_queries);
  for (int i = 0; i < num_queries; ++i) {
    q_ptrs[i] = q_all.data() + i * qkv_dim;
  }
  auto [transposed_queries, transposed_queries_ptrs, _] =
      TransposeQueriesToGroupsOfNBF16orInt16<BF16>(hwy::Span<float*>(q_ptrs),
                                                   qkv_dim, /*group_size=*/4);
  hwy::Span<const BF16*> q_T(
      const_cast<const BF16**>(transposed_queries_ptrs.data()),
      transposed_queries_ptrs.size());
  MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
                             ctx.allocator, MatPadding::kPacked);
  using DF = hn::ScalableTag<float>;
  const DF df;
  HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
  // Accumulator arrays are rounded up to a whole vector of queries.
  const size_t num_queries_rounded_to_lanes =
      hwy::RoundUpTo(static_cast<size_t>(num_queries), lanes);
  std::vector<float> exp_denominator_sums(num_queries_rounded_to_lanes);
  std::vector<float> max_logits(num_queries_rounded_to_lanes);
  for (int i = 0; i < num_queries; ++i) {
    hwy::ZeroBytes(att_out.Row(i), qkv_dim * sizeof(float));
    exp_denominator_sums[i] = 0.0f;
    max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
  }
  std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
  std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
  start_pos_per_query.reserve(num_queries);
  last_pos_per_query.reserve(num_queries);
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const ssize_t query_last_pos = kv_seq_end + token_idx;
    const ssize_t query_start_pos =
        std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
    for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
         ++q_head_idx) {
      start_pos_per_query.push_back(query_start_pos);
      last_pos_per_query.push_back(query_last_pos);
    }
  }
  hwy::Span<const MatPtr> kvs(&kv, 1);
  DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
      kvs, num_queries, q_T, hwy::Span<const size_t>(start_pos_per_query),
      hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
      exp_denominator_sums.data(), max_logits.data());
  for (int i = 0; i < num_queries; ++i) {
    // int8 quantization loses precision, hence looser tolerances than the
    // pure-float tests.
    EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
        << "i=" << i;
    EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
    for (int j = 0; j < qkv_dim; ++j) {
      EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f)
          << "i=" << i << " j=" << j;
    }
  }
}
// Exercises the tiled flash attention path that reads int8-compressed KV
// tiles with int16-quantized queries (plus per-query dequantization scales),
// comparing outputs against the precomputed float gold values.
void TestTiledFlashAttentionInt8Int16() {
  const int qkv_dim = 64;
  // Number of tokens we will attend to. Not divisible by the tile size so
  // that the padding logic is exercised.
  const int kv_seq_len = 60;
  const int padded_kv_seq_len =
      hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
  const float att_cap = 10.0f;
  const int num_queries = 8;
  const int num_queries_per_timestep = 4;
  const int num_tokens = num_queries / num_queries_per_timestep;
  const int kv_seq_end =
      kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
  ThreadingArgs threading_args;
  ThreadingContext ctx(threading_args);
  const int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
  // Each tile: int8 K and V planes plus two BF16 microscaling rows.
  const int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
                              2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
  MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
                         ctx.allocator, MatPadding::kPacked);
  // Fill in kvs with predictable, synthetic data matching BF16 paired layout.
  PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8TwoTranspositions, qkv_dim);
  // Synthetic queries with a predictable per-(query, dim) pattern.
  std::vector<float> q_all(num_queries * qkv_dim);
  for (int i = 0; i < num_queries; ++i) {
    for (int j = 0; j < qkv_dim; ++j) {
      q_all[i * qkv_dim + j] = 0.01f * (i + 1) / (j + 1);
    }
  }
  std::vector<float*> q_ptrs(num_queries);
  for (int i = 0; i < num_queries; ++i) {
    q_ptrs[i] = q_all.data() + i * qkv_dim;
  }
  // The int16 transpose also yields per-query dequantization scales.
  auto [transposed_queries, transposed_queries_ptrs, q_scales] =
      TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(
          hwy::Span<float*>(q_ptrs), qkv_dim, /*group_size=*/4);
  hwy::Span<const int16_t*> q_T(
      const_cast<const int16_t**>(transposed_queries_ptrs.data()),
      transposed_queries_ptrs.size());
  MatStorageT<float> att_out("att_out", Extents2D(num_queries, qkv_dim),
                             ctx.allocator, MatPadding::kPacked);
  using DF = hn::ScalableTag<float>;
  const DF df;
  HWY_LANES_CONSTEXPR size_t lanes = hn::Lanes(df);
  // Accumulator arrays are rounded up to a whole vector of queries.
  const size_t num_queries_rounded_to_lanes =
      hwy::RoundUpTo(static_cast<size_t>(num_queries), lanes);
  std::vector<float> exp_denominator_sums(num_queries_rounded_to_lanes);
  std::vector<float> max_logits(num_queries_rounded_to_lanes);
  for (int i = 0; i < num_queries; ++i) {
    hwy::ZeroBytes(att_out.Row(i), qkv_dim * sizeof(float));
    exp_denominator_sums[i] = 0.0f;
    max_logits[i] = -std::numeric_limits<float>::max() / 2.0f;
  }
  std::vector<size_t, hwy::AlignedAllocator<size_t>> start_pos_per_query;
  std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
  start_pos_per_query.reserve(num_queries);
  last_pos_per_query.reserve(num_queries);
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const ssize_t query_last_pos = kv_seq_end + token_idx;
    const ssize_t query_start_pos =
        std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
    for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
         ++q_head_idx) {
      start_pos_per_query.push_back(query_start_pos);
      last_pos_per_query.push_back(query_last_pos);
    }
  }
  hwy::Span<const MatPtr> kvs(&kv, 1);
  DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
      kvs, num_queries, q_T, q_scales,
      hwy::Span<const size_t>(start_pos_per_query),
      hwy::Span<const size_t>(last_pos_per_query), att_cap, att_out,
      exp_denominator_sums.data(), max_logits.data());
  for (int i = 0; i < num_queries; ++i) {
    // int8 KV + int16 query quantization loses precision, hence looser
    // tolerances than the pure-float tests.
    EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
        << "i=" << i;
    EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
    for (int j = 0; j < qkv_dim; ++j) {
      EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f)
          << "i=" << i << " j=" << j;
    }
  }
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
@ -590,6 +771,8 @@ HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestAttention);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttention);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionBF16);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8BF16);
HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8Int16);
HWY_AFTER_TEST();
} // namespace gcpp

View File

@ -82,7 +82,8 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, QBatch& qbatch, MatMulEnv& env) {
if (activations.attention_impl == AttentionImpl::kFlashTransposedQs ||
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
activations.attention_impl == AttentionImpl::kFlashTransposedQsBF16 ||
activations.attention_impl == AttentionImpl::kFlashTransposedQsInt16) {
TiledAttention(
activations.attention_impl, num_tokens, layer_idx, layer,
activations.attention, qbatch, env,

View File

@ -76,24 +76,37 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args,
const RuntimeConfig& runtime_config,
const Allocator& allocator)
: allocator_(allocator) {
// clang-format off
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs ||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsInt16 ||
runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|| ((runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs
) &&
hwy::IsSame<KV_t, BF16>())) {
) {
// clang-format on
const size_t num_tiles =
hwy::DivCeil(CappedSeqLen(config, inference_args), kTileSize);
tiled_seq_len = num_tiles * kTileSize;
Type kv_cache_type;
if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16
|| hwy::IsSame<KV_t, BF16>()) {
) {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16);
} else if (runtime_config.attention_impl ==
AttentionImpl::kFlashTransposedQsInt16) {
if (runtime_config.kv_cache_type.has_value() &&
runtime_config.kv_cache_type.value() != Type::kInt8) {
HWY_WARN(
"You are have set kv_cache_type to %s, but you are using "
"FlashTransposedQsInt16 attention implementation which only "
"supports Int8. kv_cache_type will be set to Int8.",
runtime_config.kv_cache_type.value());
}
kv_cache_type = Type::kInt8;
} else {
kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kF32);
}
int tile_length = 2 * config.layer_configs[0].qkv_dim * kTileSize;
if (kv_cache_type == Type::kInt8) {
// microscaling
tile_length += 2 * sizeof(BF16) * kTileSize;
}
auto num_tiles_per_head = [](size_t window_size, size_t prefill_tbatch_size,

View File

@ -1,9 +1,11 @@
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
@ -127,6 +129,10 @@ static HWY_INLINE void ComputeQKVTransposedTile(
// Apply positional encodings and store K/V in tiled format.
hwy::Divisor div_kv_heads(kv_heads);
bool is_transposed_qs =
attention_impl == AttentionImpl::kFlashTransposedQsBF16
|| attention_impl == AttentionImpl::kFlashTransposedQsInt16;
hn::ScalableTag<float> df;
static hwy::Divisor tile_size_divisor(KVCache::kTileSize);
ParallelFor(
@ -249,7 +255,7 @@ static HWY_INLINE void ComputeQKVTransposedTile(
v_cache_values = v_buf;
}
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
if (is_transposed_qs) {
const int in_tile_idx_mod_2 = in_tile_idx % 2;
for (int dim = 0; dim < qkv_dim; dim += 2) {
const int dim_mod_2 = dim % 2;
@ -280,7 +286,7 @@ static HWY_INLINE void ComputeQKVTransposedTile(
}
Compress(k_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, 0);
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
if (is_transposed_qs) {
Compress(v_tile_vec, qkv_dim * KVCache::kTileSize, tls,
tile_packed_span, qkv_dim * KVCache::kTileSize);
}
@ -289,23 +295,6 @@ static HWY_INLINE void ComputeQKVTransposedTile(
});
}
// TODO: optimize with gathers
// This format might change in the future, when kernel will be updated to
// support more than 8 queries.
// Input (num_queries, qkv_dim)
// Output (qkv_dim, num_queries)
void TransposeQ(const MatPtrT<float>& queries,
hwy::Span<float> transposed_queries_span) {
const size_t qkv_dim = queries.Cols();
const size_t num_queries = queries.Rows();
HWY_ASSERT(transposed_queries_span.size() == num_queries * qkv_dim);
for (size_t i = 0; i < qkv_dim; i++) {
for (size_t j = 0; j < num_queries; ++j) {
transposed_queries_span[i * num_queries + j] = queries.Row(j)[i];
}
}
}
// Transposes queries
// Input: vector of pointers to subsequent queries. (allows for arbitrary
// strides)
@ -375,6 +364,144 @@ std::pair<AlignedFloatVector, std::vector<float*>> TransposeQueriesToGroupsOf4(
std::move(transposed_queries_ptrs));
}
// Transposes up to half-a-vector of strided float queries into an interleaved
// (dim-pair-major) OutT layout: for each pair of feature dims, the two values
// of every query are stored consecutively. OutT is BF16 or int16; for int16,
// each query is symmetrically quantized to the int16 range and its inverse
// scale is written to q_scales.
// NOTE(review): pairs of floats are gathered as doubles via offsets measured
// in double units from queries[0] — this assumes all query rows live in one
// allocation at even float offsets from queries[0]; confirm with callers.
template <typename OutT>
static HWY_INLINE void TransposeStridedQueriesBF16orInt16(
    hwy::Span<const float*> queries, int qkv_dim,
    hwy::Span<OutT> transposed_queries, hwy::Span<float> q_scales) {
  namespace hn = hwy::HWY_NAMESPACE;
  using DF = hn::ScalableTag<float>;
  const DF df;
  using VF = hn::Vec<DF>;
  // doubles to avoid moving between int/float domains when gathering
  using DF64 = hn::ScalableTag<double>;
  const DF64 dd64;
  using DI64 = hn::ScalableTag<int64_t>;
  const DI64 di64;
  using VI64 = hn::Vec<DI64>;
  auto d_out = hn::Rebind<OutT, decltype(df)>();
  const size_t lanes = hn::Lanes(df);
  // Each double gather lane yields two floats, so one step covers half_lanes
  // queries.
  const size_t half_lanes = lanes / 2;
  const size_t num_queries = queries.size();
  const size_t num_numbers_to_gather = num_queries * 2;
  const size_t num_queries_rounded_up = hwy::RoundUpTo(num_queries, half_lanes);
  const size_t num_scales_rounded_up =
      hwy::RoundUpTo(num_numbers_to_gather, lanes);
  // We store scales twice so we will be able to just load them without a need
  // to duplicate for multiplication
  AlignedFloatVector inverted_q_scales_doubled(num_scales_rounded_up);
  if constexpr (IsInt16<OutT>()) {
    // compute microscales
    for (size_t i = 0; i < num_queries; ++i) {
      float max_abs = AbsMaxOfSpan(hwy::Span<const float>(queries[i], qkv_dim));
      // Map the largest |value| to 32767; an all-zero query keeps scale 1.
      float scale = max_abs == 0.0f ? 1.0f : 32767.0f / max_abs;
      inverted_q_scales_doubled[2 * i] = scale;
      inverted_q_scales_doubled[2 * i + 1] = scale;
      q_scales[i] = 1.0f / scale;  // dequantization factor for the caller
    }
  }
  // Gather offsets in units of double (hence the /2), relative to queries[0].
  std::vector<int64_t, hwy::AlignedAllocator<int64_t>> query_offsets(
      num_queries_rounded_up);
  for (size_t i = 0; i < num_queries; ++i) {
    query_offsets[i] = (queries[i] - queries[0]) / 2;
  }
  for (size_t i = num_queries; i < num_queries_rounded_up; ++i) {
    // last offset is the same so gather doesn't read out of bounds
    query_offsets[i] = query_offsets[num_queries > 0 ? num_queries - 1 : 0];
  }
  const double* queries_0_double = HWY_RCAST_ALIGNED(const double*, queries[0]);
  // Lambda to handle the scaling and demotion for Int16 types.
  // (Only invoked on the int16 path; the BF16 branches below demote inline.)
  auto process_values = [&]() HWY_ATTR {
    if constexpr (IsInt16<OutT>()) {
      return [&](VF& x, size_t j) HWY_ATTR {
        VF scales = hn::Load(df, inverted_q_scales_doubled.data() + j * 2);
        x = hn::Mul(x, scales);
        return hn::DemoteTo(d_out, hn::NearestInt(x));
      };
    } else {
      return [&](VF& x, size_t j) HWY_ATTR { return hn::DemoteTo(d_out, x); };
    }
  }();
  // Process feature dims in pairs (one double per query per step).
  for (size_t i = 0; i < qkv_dim; i += 2) {
    size_t j = 0;
    // Full steps of half_lanes queries each.
    if (num_queries >= half_lanes) {
      for (; j <= num_queries - half_lanes; j += half_lanes) {
        const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
        auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
        VF x = hn::BitCast(df, x64);  // reinterpret each double as 2 floats
        if constexpr (IsInt16<OutT>()) {
          auto demoted = process_values(x, j);
          hn::Store(demoted, d_out,
                    transposed_queries.data() + i * num_queries + j * 2);
        } else if constexpr (IsBF16<OutT>()) {
          auto demoted = hn::DemoteTo(d_out, x);
          hn::Store(demoted, d_out,
                    transposed_queries.data() + i * num_queries + j * 2);
        } else {
          static_assert(false, "Unsupported type");
        }
      }
    }
    // Remainder: partial store of the trailing (num_queries - j) queries.
    if (j < num_queries) {
      const VI64 offsets = hn::LoadU(di64, query_offsets.data() + j);
      auto x64 = hn::GatherIndex(dd64, queries_0_double + i / 2, offsets);
      VF x = hn::BitCast(df, x64);
      if constexpr (IsInt16<OutT>()) {
        auto demoted = process_values(x, j);
        hn::StoreN(demoted, d_out,
                   transposed_queries.data() + i * num_queries + j * 2,
                   num_numbers_to_gather - j * 2);
      } else if constexpr (IsBF16<OutT>()) {
        auto demoted = hn::DemoteTo(d_out, x);
        hn::StoreN(demoted, d_out,
                   transposed_queries.data() + i * num_queries + j * 2,
                   num_numbers_to_gather - j * 2);
      } else {
        static_assert(false, "Unsupported type");
      }
    }
  }
}
// Transposes `queries_ptrs` into groups of `group_size` queries, each group
// packed contiguously in (qkv_dim, group_size) order and demoted to OutT
// (BF16 or int16). For int16, per-query dequantization scales are recorded.
// Returns: the transposed-queries storage, one pointer per group into that
// storage, and the per-query scales (filled only on the int16 path).
template <typename OutT>
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, std::vector<OutT*>,
           AlignedFloatVector>
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs,
                                       int qkv_dim, size_t group_size) {
  size_t num_queries = queries_ptrs.size();
  size_t num_groups = hwy::DivCeil(num_queries, group_size);
  std::vector<OutT, hwy::AlignedAllocator<OutT>> transposed_queries(
      num_groups * group_size * qkv_dim);
  std::vector<OutT*> transposed_queries_ptrs;
  transposed_queries_ptrs.reserve(num_groups);
  // One scale slot per query. Scales are indexed below by
  // group_idx * group_size, so the allocation must track group_size (the
  // previous hard-coded `num_groups * 4` under-allocated for group_size > 4).
  AlignedFloatVector q_scales(num_groups * group_size);
  for (size_t group_idx = 0; group_idx < num_groups; ++group_idx) {
    // The last group may be partially filled.
    size_t current_group_size =
        std::min(group_size, num_queries - group_idx * group_size);
    transposed_queries_ptrs.push_back(transposed_queries.data() +
                                      group_idx * qkv_dim * group_size);
    TransposeStridedQueriesBF16orInt16(
        hwy::Span<const float*>(
            const_cast<const float**>(queries_ptrs.data() +
                                      group_idx * group_size),
            current_group_size),
        qkv_dim,
        hwy::Span<OutT>(transposed_queries_ptrs.back(),
                        qkv_dim * current_group_size),
        hwy::Span<float>(q_scales.data() + group_idx * group_size,
                         current_group_size));
  }
  return std::make_tuple(std::move(transposed_queries),
                         std::move(transposed_queries_ptrs),
                         std::move(q_scales));
}
std::pair<AlignedBF16Vector, std::vector<BF16*>>
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
int qkv_dim, int num_queries) {
@ -537,9 +664,6 @@ void LocalAttentionForAllHeadsTokensAndBatch(
hwy::Span<float*> queries_ptrs_span(queries_ptrs.data(),
queries_ptrs.size());
auto [transposed_queries, transposed_queries_ptrs] =
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
MatStorageT<float>& att_out =
activations.sub_task_att_out->at(task_idx);
AlignedFloatVector& exp_denominator_sums =
@ -604,23 +728,37 @@ void LocalAttentionForAllHeadsTokensAndBatch(
last_pos_per_query.push_back(query_last_context_pos);
}
}
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16) {
// pack transposed queries into BF16
hwy::Span<float*> queries_span(transposed_queries_ptrs.data(),
transposed_queries_ptrs.size());
auto [_, transposed_queries_ptrs_bf16] =
TransposeTransposedQueriesAndPackIntoBF16(queries_span, qkv_dim,
num_queries);
hwy::Span<const BF16*> queries_span_bf16(
const_cast<const BF16**>(transposed_queries_ptrs_bf16.data()),
transposed_queries_ptrs_bf16.size());
auto [transposed_queries, transposed_queries_ptrs, _] =
TransposeQueriesToGroupsOfNBF16orInt16<BF16>(
queries_ptrs_span, qkv_dim, /*group_size=*/4);
hwy::Span<const BF16*> queries_span(
const_cast<const BF16**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kv_ptrs, num_queries, queries_span_bf16,
kv_ptrs, num_queries, queries_span,
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
} else if (attention_impl == AttentionImpl::kFlashTransposedQsInt16) {
auto [transposed_queries, transposed_queries_ptrs, q_scales] =
TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(
queries_ptrs_span, qkv_dim, /*group_size=*/4);
hwy::Span<const int16_t*> queries_span(
const_cast<const int16_t**>(transposed_queries_ptrs.data()),
transposed_queries_ptrs.size());
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
kv_ptrs, num_queries, queries_span, q_scales,
hwy::Span<const size_t>(start_pos_per_query),
hwy::Span<const size_t>(last_pos_per_query),
activations.config.att_cap, att_out, exp_denominator_sums.data(),
max_logits.data());
} else {
auto [transposed_queries, transposed_queries_ptrs] =
TransposeQueriesToGroupsOf4(queries_ptrs_span, qkv_dim);
DispatchTileFlashAttentionReturnExpSumsAndMaxLogits(
kv_ptrs, num_queries,
hwy::Span<const float*>(
@ -712,8 +850,9 @@ void TiledAttention(AttentionImpl attention_impl, size_t num_tokens,
ComputeQKVTransposedTile<BF16>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() == Type::kF32) {
ComputeQKVTransposedTile<KV_t>(num_tokens, layer_idx, layer, attention_impl,
activations, qbatch, flags, env);
ComputeQKVTransposedTile<float>(num_tokens, layer_idx, layer,
attention_impl, activations, qbatch, flags,
env);
} else if (qbatch.KV(0).cache->compact_kv_cache_ptr.GetType() ==
Type::kInt8) {
ComputeQKVTransposedTile<int8_t>(num_tokens, layer_idx, layer,

View File

@ -28,6 +28,12 @@ namespace gcpp {
const size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
\
template <typename OutT> \
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
std::vector<OutT*>, AlignedFloatVector> \
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
int qkv_dim, size_t group_size); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -1078,6 +1078,148 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8(
HWY_DASSERT(qkv_dim == i);
}
// Accumulates up to N (<= 8) probability-weighted rows of an int8 V tile into
// float `out`. The per-query probability vectors (c*_p0/c*_p1) are quantized
// to int16, the int8 V values are widened to int16, and products are
// accumulated in int32 lanes via ReorderWidenMulAccumulate — avoiding f32 or
// BF16 multiplies. `scales[j]` rescales the previous contents of out.Row(j)
// (flash-attention running rescale) and `q_scales_s[j]` converts the integer
// accumulators back to the float domain.
// NOTE(review): v_tile appears to hold 2 * qkv_dim int8 values per tile row,
// with feature-dim pairs interleaved — confirm against the tile writer.
template <int32_t N, class DF, class VF = hn::Vec<DF>>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16_Int16(
    DF df, const float* HWY_RESTRICT scales, const VF& c0_p0, const VF& c0_p1,
    const VF& c1_p0, const VF& c1_p1, const VF& c2_p0, const VF& c2_p1,
    const VF& c3_p0, const VF& c3_p1, const VF& c4_p0, const VF& c4_p1,
    const VF& c5_p0, const VF& c5_p1, const VF& c6_p0, const VF& c6_p1,
    const VF& c7_p0, const VF& c7_p1, const int8_t* HWY_RESTRICT v_tile,
    MatPtrT<float>& out, const float* HWY_RESTRICT q_scales_s) {
  static_assert(N <= 8);
  namespace hn = hwy::HWY_NAMESPACE;
  const size_t qkv_dim = out.Cols();
  constexpr size_t kMaxLanes = hn::MaxLanes(df);
  HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
  using DI16 = hn::Repartition<int16_t, DF>;
  const DI16 di16;
  // NOTE(review): di16_half appears unused — candidate for removal.
  const auto di16_half = hn::Half<DI16>();
  using DI32 = hn::Repartition<int32_t, DF>;
  const DI32 di32;
  using VI16 = hn::Vec<DI16>;
  using VI32 = hn::Vec<DI32>;
  using DI8 = hn::Repartition<int8_t, DF>;
  const hn::Half<DI8> di8_half;
  HWY_LANES_CONSTEXPR size_t kInt16Lanes = hn::Lanes(di16);
  // Staging buffer for the int16-quantized probabilities, 2*kMaxLanes per
  // query row.
  HWY_ALIGN int16_t cs_i16[N * kMaxLanes * 2];
  // Round probabilities to nearest int and pack the two halves into one int16
  // vector per query row.
  auto quantize_s_and_store = [&](int j, const VF& p0, const VF& p1) HWY_ATTR {
    auto i0 =
        hn::OrderedDemote2To(di16, hn::NearestInt(p0), hn::NearestInt(p1));
    hn::Store(i0, di16, cs_i16 + j * kMaxLanes * 2);
  };
  quantize_s_and_store(0, c0_p0, c0_p1);
  if constexpr (N >= 2) quantize_s_and_store(1, c1_p0, c1_p1);
  if constexpr (N >= 3) quantize_s_and_store(2, c2_p0, c2_p1);
  if constexpr (N >= 4) quantize_s_and_store(3, c3_p0, c3_p1);
  if constexpr (N >= 5) quantize_s_and_store(4, c4_p0, c4_p1);
  if constexpr (N >= 6) quantize_s_and_store(5, c5_p0, c5_p1);
  if constexpr (N >= 7) quantize_s_and_store(6, c6_p0, c6_p1);
  if constexpr (N >= 8) quantize_s_and_store(7, c7_p0, c7_p1);
  size_t i = 0;
  HWY_DASSERT(qkv_dim % (NF * 2) == 0);
  // Process 2*NF output floats per iteration; each query row j keeps four
  // int32 accumulators (even/odd lanes for each of the two float vectors).
  while (i + 2 * NF <= qkv_dim) {
    VI32 acc0_0 = hn::Zero(di32), acc0_1 = hn::Zero(di32);
    VI32 acc1_0 = hn::Zero(di32), acc1_1 = hn::Zero(di32);
    VI32 acc2_0 = hn::Zero(di32), acc2_1 = hn::Zero(di32);
    VI32 acc3_0 = hn::Zero(di32), acc3_1 = hn::Zero(di32);
    VI32 acc4_0 = hn::Zero(di32), acc4_1 = hn::Zero(di32);
    VI32 acc5_0 = hn::Zero(di32), acc5_1 = hn::Zero(di32);
    VI32 acc6_0 = hn::Zero(di32), acc6_1 = hn::Zero(di32);
    VI32 acc7_0 = hn::Zero(di32), acc7_1 = hn::Zero(di32);
    VI32 acc0_o_0 = hn::Zero(di32), acc0_o_1 = hn::Zero(di32);
    VI32 acc1_o_0 = hn::Zero(di32), acc1_o_1 = hn::Zero(di32);
    VI32 acc2_o_0 = hn::Zero(di32), acc2_o_1 = hn::Zero(di32);
    VI32 acc3_o_0 = hn::Zero(di32), acc3_o_1 = hn::Zero(di32);
    VI32 acc4_o_0 = hn::Zero(di32), acc4_o_1 = hn::Zero(di32);
    VI32 acc5_o_0 = hn::Zero(di32), acc5_o_1 = hn::Zero(di32);
    VI32 acc6_o_0 = hn::Zero(di32), acc6_o_1 = hn::Zero(di32);
    VI32 acc7_o_0 = hn::Zero(di32), acc7_o_1 = hn::Zero(di32);
    // One tile row ("lane") at a time: widen its int8 V values to int16 and
    // accumulate against each query's quantized probability for that row.
    for (int lane = 0; lane < NF; ++lane) {
      VI16 vi_first8, vi_next8;
      const int8_t* v_ptr = v_tile + 2 * qkv_dim * lane + i * 2;
      auto v8_t0 = hn::LoadU(di8_half, v_ptr);
      auto v16_t0 = hn::PromoteTo(di16, v8_t0);
      auto v8_t1 = hn::LoadU(di8_half, v_ptr + kInt16Lanes);
      auto v16_t1 = hn::PromoteTo(di16, v8_t1);
      vi_first8 = v16_t0;
      vi_next8 = v16_t1;
      // Broadcast the query's interleaved probability pair (s0, s1) across
      // the vector as an int32 splat reinterpreted as int16 lanes, then do
      // the int16*int16 -> int32 widening multiply-accumulate.
      auto mul_acc = [&](int j, VI32& a0, VI32& a_o0, VI32& a1,
                         VI32& a_o1) HWY_ATTR {
        int16_t s0 = cs_i16[2 * lane + j * kMaxLanes * 2];
        int16_t s1 = cs_i16[2 * lane + 1 + j * kMaxLanes * 2];
        int32_t s01;
        hwy::CopySameSize(&s0, reinterpret_cast<int16_t*>(&s01));
        hwy::CopySameSize(&s1, reinterpret_cast<int16_t*>(&s01) + 1);
        VI16 sj = hn::BitCast(di16, hn::Set(di32, s01));
        a0 = hn::ReorderWidenMulAccumulate(di32, vi_first8, sj, a0, a_o0);
        a1 = hn::ReorderWidenMulAccumulate(di32, vi_next8, sj, a1, a_o1);
      };
      mul_acc(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1);
      if constexpr (N >= 2) mul_acc(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1);
      if constexpr (N >= 3) mul_acc(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1);
      if constexpr (N >= 4) mul_acc(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1);
      if constexpr (N >= 5) mul_acc(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1);
      if constexpr (N >= 6) mul_acc(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1);
      if constexpr (N >= 7) mul_acc(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1);
      if constexpr (N >= 8) mul_acc(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1);
    }
    // Merge even/odd accumulators, convert to float, rescale the previous
    // output by scales[j], and add the new contribution scaled by
    // q_scales_s[j].
    auto convert_and_add = [&](int j, VI32& a0, VI32& a_o0, VI32& a1,
                               VI32& a_o1) HWY_ATTR {
      VF f0 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a0, a_o0));
      VF f1 = hn::ConvertTo(df, hn::RearrangeToOddPlusEven(a1, a_o1));
      VF o0 = hn::Load(df, out.Row(j) + i);
      VF o1 = hn::Load(df, out.Row(j) + i + NF);
      VF scale_old = hn::Set(df, scales[j]);
      o0 = hn::Mul(o0, scale_old);
      o1 = hn::Mul(o1, scale_old);
      VF scale_new = hn::Set(df, q_scales_s[j]);
      o0 = hn::MulAdd(f0, scale_new, o0);
      o1 = hn::MulAdd(f1, scale_new, o1);
      hn::Store(o0, df, out.Row(j) + i);
      hn::Store(o1, df, out.Row(j) + i + NF);
    };
    convert_and_add(0, acc0_0, acc0_o_0, acc0_1, acc0_o_1);
    if constexpr (N >= 2)
      convert_and_add(1, acc1_0, acc1_o_0, acc1_1, acc1_o_1);
    if constexpr (N >= 3)
      convert_and_add(2, acc2_0, acc2_o_0, acc2_1, acc2_o_1);
    if constexpr (N >= 4)
      convert_and_add(3, acc3_0, acc3_o_0, acc3_1, acc3_o_1);
    if constexpr (N >= 5)
      convert_and_add(4, acc4_0, acc4_o_0, acc4_1, acc4_o_1);
    if constexpr (N >= 6)
      convert_and_add(5, acc5_0, acc5_o_0, acc5_1, acc5_o_1);
    if constexpr (N >= 7)
      convert_and_add(6, acc6_0, acc6_o_0, acc6_1, acc6_o_1);
    if constexpr (N >= 8)
      convert_and_add(7, acc7_0, acc7_o_0, acc7_1, acc7_o_1);
    i += 2 * NF;
  }
}
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
DF df, const float* HWY_RESTRICT scales, VF c0_p0, VF c0_p1, VF c1_p0,
@ -1600,6 +1742,50 @@ MatStorageT<T> AvgPool4x4(MatStorageT<T>& input, const Allocator& allocator) {
return result;
}
// Reduces each of x and stores in following lanes of max (tested with float32):
// lane k of the returned 4-lane vector is `reducer` folded across all lanes of
// x_k. Works by interleaving the four vectors so that corresponding source
// vectors line up modulo 4, then folding 4-lane chunks with `reducer`.
template <class DF, typename T = hn::TFromD<DF>,
          class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
          class VF = hn::Vec<DF>, typename F>
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
                              F reducer) {
  const DF4 df4;
  constexpr size_t kMaxLanes = hn::MaxLanes(df);
  HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
  HWY_ALIGN T x_transposed[4 * kMaxLanes];
  // After interleaving, element p of the buffer is x_{p%4}[p/4].
  hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
  // Fold the 4*kLanes interleaved elements down to one vector of kLanes.
  VF x01 =
      reducer(hn::Load(df, x_transposed), hn::Load(df, x_transposed + kLanes));
  VF x23 = reducer(hn::Load(df, x_transposed + 2 * kLanes),
                   hn::Load(df, x_transposed + 3 * kLanes));
  VF x0123 = reducer(x01, x23);
  hn::Store(x0123, df, x_transposed);
  VF4 result = hn::Load(df4, x_transposed);
  // Fold the remaining 4-lane groups; no-op when the vector has 4 lanes.
  for (int i = 1; i < kLanes / 4; ++i) {
    result = reducer(result, hn::Load(df4, x_transposed + i * 4));
  }
  return result;
}
// Returns vector with 8 lanes. Shouldn't be used on architectures
// with less than 8 lanes per vector.
template <class DF, typename T = hn::TFromD<DF>,
          class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
          class VF = hn::Vec<DF>, typename F>
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
                              VF x_5, VF x_6, VF x_7, F reducer) {
  using DF4 = hn::CappedTag<T, 4>;
  const DF4 df4;
  const DF8 df8;
  // Reduce each half independently, then concatenate the two 4-lane results
  // through a small stack buffer.
  const auto lo = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
  const auto hi = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
  HWY_ALIGN T combined[8];
  hn::Store(lo, df4, combined);
  hn::Store(hi, df4, combined + 4);
  return hn::Load(df8, combined);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -481,6 +481,16 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
}
}
// Creates typed MatPtrT<T> views for every type-erased MatPtr in `base`.
template <typename T>
std::vector<MatPtrT<T>> MakeMatPtrVec(hwy::Span<const MatPtr> base) {
  std::vector<MatPtrT<T>> typed;
  typed.reserve(base.size());
  for (const MatPtr& mat : base) {
    typed.emplace_back(mat);
  }
  return typed;
}
// Calls 'func' with a span of MatPtrT<T> for all elements in `base`.
// T is dynamic type, read from base. It is assumed that all elements in `base`
// have the same type.
@ -491,25 +501,17 @@ decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
for ([[maybe_unused]] auto&& mat : base) {
HWY_DASSERT(mat.GetType() == type);
}
auto make_matptr_vec = [&base](auto element) {
std::vector<MatPtrT<decltype(element)>> matptrs;
matptrs.reserve(base.size());
for (auto&& mat : base) {
matptrs.emplace_back(mat);
}
return matptrs;
};
if (type == Type::kF32) {
auto matptrs = make_matptr_vec(float{});
auto matptrs = MakeMatPtrVec<float>(base);
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
matptrs.size());
return func(matptrs_span, std::forward<Args>(args)...);
} else if (type == Type::kBF16) {
auto matptrs = make_matptr_vec(BF16{});
auto matptrs = MakeMatPtrVec<BF16>(base);
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
return func(matptrs_span, std::forward<Args>(args)...);
} else if (type == Type::kInt8) {
auto matptrs = make_matptr_vec(int8_t{});
auto matptrs = MakeMatPtrVec<int8_t>(base);
hwy::Span<const MatPtrT<int8_t>> matptrs_span(matptrs.data(),
matptrs.size());
return func(matptrs_span, std::forward<Args>(args)...);