mirror of https://github.com/google/gemma.cpp.git
Change to use faster exponent function
PiperOrigin-RevId: 877981568
This commit is contained in:
parent
49cb438b1e
commit
539d9bb8e7
|
|
@ -633,6 +633,7 @@ cc_library(
|
|||
"//compression:compress",
|
||||
"//compression:types",
|
||||
"@highway//:hwy",
|
||||
"@highway//:math",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5 EXCLUDE_FROM_ALL)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98 EXCLUDE_FROM_ALL)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
|
||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||
|
|
|
|||
17
MODULE.bazel
17
MODULE.bazel
|
|
@ -3,22 +3,23 @@ module(
|
|||
version = "0.1.0",
|
||||
)
|
||||
|
||||
bazel_dep(name = "abseil-cpp", version = "20240722.0")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.7.1")
|
||||
bazel_dep(name = "googletest", version = "1.15.2")
|
||||
bazel_dep(name = "abseil-cpp", version = "20250814.1")
|
||||
bazel_dep(name = "bazel_skylib", version = "1.8.1")
|
||||
bazel_dep(name = "googletest", version = "1.17.0")
|
||||
bazel_dep(name = "highway", version = "1.1.0")
|
||||
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
||||
bazel_dep(name = "platforms", version = "0.0.10")
|
||||
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
|
||||
bazel_dep(name = "rules_cc", version = "0.0.16")
|
||||
bazel_dep(name = "protobuf", version = "33.4")
|
||||
bazel_dep(name = "platforms", version = "1.0.0")
|
||||
bazel_dep(name = "pybind11_bazel", version = "2.13.6")
|
||||
bazel_dep(name = "rules_cc", version = "0.2.0")
|
||||
bazel_dep(name = "rules_license", version = "1.0.0")
|
||||
bazel_dep(name = "rules_python", version = "1.0.0")
|
||||
bazel_dep(name = "rules_python", version = "1.6.3")
|
||||
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||
|
||||
# Require a more recent version.
|
||||
git_override(
|
||||
module_name = "highway",
|
||||
commit = "3b680cde3a556bead9cc23c8f595d07a44d5a0d5",
|
||||
commit = "c971dbe61bd2751923e3458666450bf95dfbbd98",
|
||||
remote = "https://github.com/google/highway",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -451,7 +451,7 @@ FetchContent_MakeAvailable(sentencepiece)
|
|||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main)
|
||||
FetchContent_MakeAvailable(gemma)
|
||||
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
include(FetchContent)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5)
|
||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98)
|
||||
FetchContent_MakeAvailable(highway)
|
||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
||||
FetchContent_MakeAvailable(sentencepiece)
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@
|
|||
#include "gemma/attention.h"
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "hwy/contrib/math/fast_math-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -640,25 +641,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
new_max = hn::Max(new_max, old_max_vf);
|
||||
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
|
||||
hn::StoreU(new_max, df4, old_max);
|
||||
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
|
||||
const VF new_max_i = hn::Set(df, old_max[i]);
|
||||
x_p0 = hn::FastExpMinusOrZero(df, hn::Sub(x_p0, new_max_i));
|
||||
x_p1 = hn::FastExpMinusOrZero(df, hn::Sub(x_p1, new_max_i));
|
||||
};
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[0]);
|
||||
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
|
||||
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
|
||||
apply_exp(0, x_0_p0, x_0_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[1]);
|
||||
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
|
||||
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
|
||||
apply_exp(1, x_1_p0, x_1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[2]);
|
||||
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
|
||||
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
|
||||
apply_exp(2, x_2_p0, x_2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[3]);
|
||||
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
|
||||
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
|
||||
apply_exp(3, x_3_p0, x_3_p1);
|
||||
}
|
||||
VF4 old_d_vf = hn::Set(df4, 0.0f);
|
||||
old_d_vf = hn::LoadU(df4, old_d);
|
||||
|
|
@ -709,10 +707,6 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
}
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_NOINLINE VF CallExp(DF df, VF x_p0) {
|
||||
return hn::Exp(df, x_p0);
|
||||
}
|
||||
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
||||
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
|
||||
|
|
@ -766,45 +760,36 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
new_max = hn::Max(new_max, old_max_vf);
|
||||
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
|
||||
hn::StoreU(new_max, df8, old_max);
|
||||
|
||||
auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR {
|
||||
const VF new_max_i = hn::Set(df, old_max[i]);
|
||||
x_p0 = hn::FastExpMinusOrZero(df, hn::Sub(x_p0, new_max_i));
|
||||
x_p1 = hn::FastExpMinusOrZero(df, hn::Sub(x_p1, new_max_i));
|
||||
};
|
||||
|
||||
if constexpr (kNumQueries >= 1) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[0]);
|
||||
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
|
||||
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
|
||||
apply_exp(0, x_0_p0, x_0_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 2) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[1]);
|
||||
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
|
||||
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
|
||||
apply_exp(1, x_1_p0, x_1_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 3) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[2]);
|
||||
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
|
||||
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
|
||||
apply_exp(2, x_2_p0, x_2_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 4) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[3]);
|
||||
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
|
||||
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
|
||||
apply_exp(3, x_3_p0, x_3_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 5) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[4]);
|
||||
x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0));
|
||||
x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0));
|
||||
apply_exp(4, x_4_p0, x_4_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 6) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[5]);
|
||||
x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0));
|
||||
x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0));
|
||||
apply_exp(5, x_5_p0, x_5_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 7) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[6]);
|
||||
x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0));
|
||||
x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0));
|
||||
apply_exp(6, x_6_p0, x_6_p1);
|
||||
}
|
||||
if constexpr (kNumQueries >= 8) {
|
||||
const VF new_max_0 = hn::Set(df, old_max[7]);
|
||||
x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0));
|
||||
x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0));
|
||||
apply_exp(7, x_7_p0, x_7_p1);
|
||||
}
|
||||
VF8 old_d_vf = hn::Set(df8, 0.0f);
|
||||
old_d_vf = hn::LoadU(df8, old_d);
|
||||
|
|
|
|||
|
|
@ -366,11 +366,11 @@ void TestTiledFlashAttention() {
|
|||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f)
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-3f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f);
|
||||
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-5f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -481,7 +481,7 @@ void TestTiledFlashAttentionBF16() {
|
|||
for (int i = 0; i < num_queries; ++i) {
|
||||
std::cerr << "exp_d: " << exp_denominator_sums[i]
|
||||
<< " max_logit: " << max_logits[i] << std::endl;
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
|
||||
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
|
||||
<< "i=" << i;
|
||||
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
|
||||
for (int j = 0; j < qkv_dim; ++j) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue