mirror of https://github.com/google/gemma.cpp.git
parent
56fa6e4839
commit
7e5310b908
|
|
@ -446,7 +446,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
F reducer) {
|
||||
const DF4 df4;
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
|
|
@ -469,7 +469,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
|||
|
||||
// Handles up to 4 Q rows by NF*2 timesteps of flash attention.
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
|
||||
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
|
||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
||||
|
|
|
|||
Loading…
Reference in New Issue