mirror of https://github.com/google/gemma.cpp.git
parent
0c64987a96
commit
44dfd69b9b
|
|
@ -142,6 +142,7 @@ cc_test(
|
|||
":mat",
|
||||
":matmul",
|
||||
":query",
|
||||
":test_util",
|
||||
":threading_context",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@
|
|||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||
#include "gemma/flash_structs.h"
|
||||
|
|
@ -438,6 +440,152 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
|||
return scale;
|
||||
}
|
||||
|
||||
// Reduces each of x and stores in following lanes of max (tested with float32)
|
||||
template <class DF, typename T = hn::TFromD<DF>,
|
||||
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
|
||||
class VF = hn::Vec<DF>, typename F>
|
||||
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
|
||||
F reducer) {
|
||||
const DF4 df4;
|
||||
constexpr size_t kMaxLanes = hn::MaxLanes(df);
|
||||
HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
|
||||
HWY_ALIGN T x_transposed[4 * kMaxLanes];
|
||||
hn::StoreInterleaved4<DF>(x_0, x_1, x_2, x_3, df, x_transposed);
|
||||
VF4 result = hn::Load(df4, x_transposed);
|
||||
for (int i = 1; i < kLanes; ++i) {
|
||||
result = reducer(result, hn::Load(df4, x_transposed + i * 4));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
|
||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
|
||||
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
|
||||
float* HWY_RESTRICT scales) {
|
||||
using DF4 = hn::CappedTag<float, 4>;
|
||||
const DF4 df4;
|
||||
using VF4 = hn::Vec<DF4>;
|
||||
static_assert(kNumQueries >= 1 && kNumQueries <= 4);
|
||||
VF4 new_max = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
|
||||
VF max_0, max_1, max_2, max_3 = hn::Zero(df);
|
||||
max_0 = hn::Max(x_0_p0, x_0_p1);
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
max_1 = hn::Max(x_1_p0, x_1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
max_2 = hn::Max(x_2_p0, x_2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
max_3 = hn::Max(x_3_p0, x_3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries == 1) {
|
||||
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
|
||||
} else {
|
||||
new_max = Reduce4(df, max_0, max_1, max_2, max_3,
|
||||
[](auto a, auto b) { return hn::Max(a, b); });
|
||||
}
|
||||
if (att_cap > 0.0f) {
|
||||
VF4 cap = hn::Set(df4, att_cap);
|
||||
VF4 one_over_cap = hn::Set(df4, one_over_att_cap);
|
||||
new_max = hn::Mul(cap, hn::Tanh(df4, hn::Mul(new_max, one_over_cap)));
|
||||
}
|
||||
VF4 old_max_vf = hn::Set(df4, -std::numeric_limits<float>::max() / 2.0f);
|
||||
old_max_vf = hn::LoadU(df4, old_max);
|
||||
new_max = hn::Max(new_max, old_max_vf);
|
||||
// TODO figure out what was wrong with broadcasts and change to that.
|
||||
HWY_ALIGN float tmp_max[4];
|
||||
hn::Store(new_max, df4, tmp_max);
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[0]);
|
||||
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0 , new_max_0));
|
||||
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[1]);
|
||||
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
|
||||
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[2]);
|
||||
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
|
||||
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
const VF new_max_0 = hn::Set(df, tmp_max[3]);
|
||||
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
|
||||
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
|
||||
}
|
||||
VF4 old_d_vf = hn::Set(df4, 0.0f);
|
||||
old_d_vf = hn::LoadU(df4, old_d);
|
||||
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
|
||||
|
||||
hn::StoreU(new_max, df4, old_max);
|
||||
|
||||
VF4 x_sum = hn::Zero(df4);
|
||||
if constexpr (kNumQueries == 1) {
|
||||
x_sum = hn::Set(df4, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
|
||||
} else {
|
||||
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
|
||||
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
|
||||
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
|
||||
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
|
||||
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
|
||||
[](auto a, auto b) { return hn::Add(a, b); });
|
||||
}
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
const VF4 zero4 = hn::Zero(df4);
|
||||
const VF4 one_over_d =
|
||||
hn::MaskedDivOr(zero4, non_zero_mask, hn::Set(df4, 1.0f), old_d_vf);
|
||||
float tmp_one_over_d[4];
|
||||
hn::Store(one_over_d, df4, tmp_one_over_d);
|
||||
hn::Store(old_d_vf, df4, old_d);
|
||||
scale = hn::Mul(scale, one_over_d);
|
||||
hn::Store(scale, df4, scales);
|
||||
if (hn::ExtractLane(old_d_vf, 0) > 0.0f) {
|
||||
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
|
||||
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
|
||||
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
|
||||
} else {
|
||||
x_0_p0 = zero;
|
||||
x_0_p1 = zero;
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
if (hn::ExtractLane(old_d_vf, 1) > 0.0f) {
|
||||
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
|
||||
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
|
||||
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
|
||||
} else {
|
||||
x_1_p0 = zero;
|
||||
x_1_p1 = zero;
|
||||
}
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
if (hn::ExtractLane(old_d_vf, 2) > 0.0f) {
|
||||
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
|
||||
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
|
||||
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
|
||||
} else {
|
||||
x_2_p0 = zero;
|
||||
x_2_p1 = zero;
|
||||
}
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
if (hn::ExtractLane(old_d_vf, 3) > 0.0f) {
|
||||
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
|
||||
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
|
||||
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
|
||||
} else {
|
||||
x_3_p0 = zero;
|
||||
x_3_p1 = zero;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Implements flash attention for a strip of 4 query vectors.
|
||||
// It iterates through timesteps in K from `start_pos` up to `max_last_pos`.
|
||||
// Timesteps up to `min_last_pos` (*) are processed in tiles of shape 4 Q rows
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -24,6 +26,7 @@
|
|||
#include "gemma/kv_cache.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/test_util.h"
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
|
||||
#endif // HWY_DISABLED_TARGETS
|
||||
|
|
|
|||
15
util/mat.h
15
util/mat.h
|
|
@ -454,6 +454,21 @@ decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
|
|||
}
|
||||
}
|
||||
|
||||
// Like CallUpcasted, but only for kv_cache types: kBF16 and kF32.
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func,
|
||||
Args&&... args) {
|
||||
if (base->GetType() == Type::kF32) {
|
||||
const MatPtrT<float> mat(*base);
|
||||
return func(&mat, std::forward<Args>(args)...);
|
||||
} else if (base->GetType() == Type::kBF16) {
|
||||
const MatPtrT<BF16> mat(*base);
|
||||
return func(&mat, std::forward<Args>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
|
||||
}
|
||||
}
|
||||
|
||||
void CopyMat(const MatPtr& from, MatPtr& to);
|
||||
void ZeroInit(MatPtr& mat);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue