From 539d9bb8e77d4f591d79dbd0b3ea14956783b962 Mon Sep 17 00:00:00 2001 From: Krzysztof Rymski Date: Tue, 3 Mar 2026 09:15:42 -0800 Subject: [PATCH] Change to use faster exponent function PiperOrigin-RevId: 877981568 --- BUILD.bazel | 1 + CMakeLists.txt | 2 +- MODULE.bazel | 17 ++++--- README.md | 2 +- examples/hello_world/CMakeLists.txt | 2 +- examples/simplified_gemma/CMakeLists.txt | 2 +- gemma/flash_attention.cc | 65 +++++++++--------------- gemma/flash_attention_test.cc | 6 +-- 8 files changed, 42 insertions(+), 55 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 5f3bf87..cfd7e57 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -633,6 +633,7 @@ cc_library( "//compression:compress", "//compression:types", "@highway//:hwy", + "@highway//:math", "@highway//:profiler", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 58bfab5..000ef99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5 EXCLUDE_FROM_ALL) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98 EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(highway) ## Note: absl needs to be installed by sentencepiece. This will only happen if diff --git a/MODULE.bazel b/MODULE.bazel index d60c0f4..0dea775 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -3,22 +3,23 @@ module( version = "0.1.0", ) -bazel_dep(name = "abseil-cpp", version = "20240722.0") -bazel_dep(name = "bazel_skylib", version = "1.7.1") -bazel_dep(name = "googletest", version = "1.15.2") +bazel_dep(name = "abseil-cpp", version = "20250814.1") +bazel_dep(name = "bazel_skylib", version = "1.8.1") +bazel_dep(name = "googletest", version = "1.17.0") bazel_dep(name = "highway", version = "1.1.0") bazel_dep(name = "nlohmann_json", version = "3.11.3") -bazel_dep(name = "platforms", version = "0.0.10") -bazel_dep(name = "pybind11_bazel", version = "2.12.0") -bazel_dep(name = "rules_cc", version = "0.0.16") +bazel_dep(name = "protobuf", version = "33.4") +bazel_dep(name = "platforms", version = "1.0.0") +bazel_dep(name = "pybind11_bazel", version = "2.13.6") +bazel_dep(name = "rules_cc", version = "0.2.0") bazel_dep(name = "rules_license", version = "1.0.0") -bazel_dep(name = "rules_python", version = "1.0.0") +bazel_dep(name = "rules_python", version = "1.6.3") bazel_dep(name = "google_benchmark", version = "1.8.5") # Require a more recent version. git_override( module_name = "highway", - commit = "3b680cde3a556bead9cc23c8f595d07a44d5a0d5", + commit = "c971dbe61bd2751923e3458666450bf95dfbbd98", remote = "https://github.com/google/highway", ) diff --git a/README.md b/README.md index 0aedf38..b740f66 100644 --- a/README.md +++ b/README.md @@ -451,7 +451,7 @@ FetchContent_MakeAvailable(sentencepiece) FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp GIT_TAG origin/main) FetchContent_MakeAvailable(gemma) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98) FetchContent_MakeAvailable(highway) ``` diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 2fd94b4..6f1c090 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) diff --git a/examples/simplified_gemma/CMakeLists.txt b/examples/simplified_gemma/CMakeLists.txt index ca2e405..5b59f1b 100644 --- a/examples/simplified_gemma/CMakeLists.txt +++ b/examples/simplified_gemma/CMakeLists.txt @@ -18,7 +18,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 3b680cde3a556bead9cc23c8f595d07a44d5a0d5) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c971dbe61bd2751923e3458666450bf95dfbbd98) FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 360dc9d..f8df7e0 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -54,6 +54,7 @@ #include "gemma/attention.h" #include "ops/matmul-inl.h" #include "ops/ops-inl.h" +#include "hwy/contrib/math/fast_math-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -640,25 +641,22 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( new_max = hn::Max(new_max, old_max_vf); auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf)); hn::StoreU(new_max, df4, old_max); + auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR { + const VF new_max_i = hn::Set(df, old_max[i]); + x_p0 = hn::FastExpMinusOrZero(df, hn::Sub(x_p0, new_max_i)); + x_p1 = hn::FastExpMinusOrZero(df, hn::Sub(x_p1, new_max_i)); + }; if constexpr (kNumQueries >= 1) { - const VF new_max_0 = hn::Set(df, old_max[0]); - x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0)); - x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0)); + apply_exp(0, x_0_p0, x_0_p1); } if constexpr (kNumQueries >= 2) { - const VF new_max_0 = hn::Set(df, old_max[1]); - x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0)); - x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0)); + apply_exp(1, x_1_p0, x_1_p1); } if constexpr (kNumQueries >= 3) { - const VF new_max_0 = hn::Set(df, old_max[2]); - x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0)); - x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0)); + apply_exp(2, x_2_p0, x_2_p1); } if constexpr (kNumQueries >= 4) { - const VF new_max_0 = hn::Set(df, old_max[3]); - x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0)); - x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0)); + apply_exp(3, x_3_p0, x_3_p1); } VF4 old_d_vf = hn::Set(df4, 0.0f); old_d_vf = hn::LoadU(df4, old_d); @@ -709,10 +707,6 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4( } } -template > -HWY_NOINLINE VF CallExp(DF df, VF x_p0) { - return hn::Exp(df, x_p0); -} template > static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1, @@ -766,45 +760,36 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8( new_max = hn::Max(new_max, old_max_vf); auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf)); hn::StoreU(new_max, df8, old_max); + + auto apply_exp = [&](int i, VF& x_p0, VF& x_p1) HWY_ATTR { + const VF new_max_i = hn::Set(df, old_max[i]); + x_p0 = hn::FastExpMinusOrZero(df, hn::Sub(x_p0, new_max_i)); + x_p1 = hn::FastExpMinusOrZero(df, hn::Sub(x_p1, new_max_i)); + }; + if constexpr (kNumQueries >= 1) { - const VF new_max_0 = hn::Set(df, old_max[0]); - x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0)); - x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0)); + apply_exp(0, x_0_p0, x_0_p1); } if constexpr (kNumQueries >= 2) { - const VF new_max_0 = hn::Set(df, old_max[1]); - x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0)); - x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0)); + apply_exp(1, x_1_p0, x_1_p1); } if constexpr (kNumQueries >= 3) { - const VF new_max_0 = hn::Set(df, old_max[2]); - x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0)); - x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0)); + apply_exp(2, x_2_p0, x_2_p1); } if constexpr (kNumQueries >= 4) { - const VF new_max_0 = hn::Set(df, old_max[3]); - x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0)); - x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0)); + apply_exp(3, x_3_p0, x_3_p1); } if constexpr (kNumQueries >= 5) { - const VF new_max_0 = hn::Set(df, old_max[4]); - x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0)); - x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0)); + apply_exp(4, x_4_p0, x_4_p1); } if constexpr (kNumQueries >= 6) { - const VF new_max_0 = hn::Set(df, old_max[5]); - x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0)); - x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0)); + apply_exp(5, x_5_p0, x_5_p1); } if constexpr (kNumQueries >= 7) { - const VF new_max_0 = hn::Set(df, old_max[6]); - x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0)); - x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0)); + apply_exp(6, x_6_p0, x_6_p1); } if constexpr (kNumQueries >= 8) { - const VF new_max_0 = hn::Set(df, old_max[7]); - x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0)); - x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0)); + apply_exp(7, x_7_p0, x_7_p1); } VF8 old_d_vf = hn::Set(df8, 0.0f); old_d_vf = hn::LoadU(df8, old_d); diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 782e613..fd693d9 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -366,11 +366,11 @@ void TestTiledFlashAttention() { for (int i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-4f) + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-3f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i; for (int j = 0; j < qkv_dim; ++j) { - EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-6f); + EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-5f); } } } @@ -481,7 +481,7 @@ void TestTiledFlashAttentionBF16() { for (int i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; for (int j = 0; j < qkv_dim; ++j) {