mirror of https://github.com/google/gemma.cpp.git
parent
56fa6e4839
commit
7e5310b908
|
|
@ -446,7 +446,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
||||||
template <class DF, typename T = hn::TFromD<DF>,
|
template <class DF, typename T = hn::TFromD<DF>,
|
||||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||||
class VF = hn::Vec<DF>, typename F>
|
class VF = hn::Vec<DF>, typename F>
|
||||||
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||||
F reducer) {
|
F reducer) {
|
||||||
const DF4 df4;
|
const DF4 df4;
|
||||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||||
|
|
@ -469,7 +469,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||||
|
|
||||||
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
|
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
|
||||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||||
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
|
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
|
||||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||||
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue