From c6eb3b6f0df453acad42145b6bb7e8eca48cb55e Mon Sep 17 00:00:00 2001 From: Apoorv Reddy Date: Sun, 18 Aug 2024 23:16:29 -0700 Subject: [PATCH] VectorizedRopeAndMulBy. ~8x reduction (tested on few prompts) in Rope. ~3.8% prefill latency improvement. ~2.6% decode latency improvement. PiperOrigin-RevId: 664650108 --- BUILD.bazel | 5 ++++- gemma/gemma-inl.h | 2 +- ops/ops-inl.h | 40 ++++++++++++++++++++++++++++++++++ ops/ops_test.cc | 55 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index a6a9534..8c8fc27 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -72,11 +72,14 @@ cc_test( # for test_suite. tags = ["hwy_ops_test"], deps = [ + ":allocator", + ":common", + ":gemma_lib", ":ops", "@googletest//:gtest_main", # buildcleaner: keep "@hwy//:hwy", "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", + "@hwy//:nanobenchmark", #buildcleaner: keep ], ) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 86a8193..0f14790 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -228,7 +228,7 @@ class GemmaAttention { Rope(qk_out, kQKVDim / 2, inv_timescale, pos); MulByConst(mul, qk_out, kQKVDim); } else { - RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); + VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 9a8e7c0..2a1da98 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -420,6 +420,46 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( } } +static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy( + const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, + const float* HWY_RESTRICT inv_timescale, int pos, + float* HWY_RESTRICT x_out) { + PROFILER_FUNC; + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + + using D = hn::ScalableTag; + using V = hn::Vec; + const D d; + + HWY_DASSERT(half_dim_qkv % hn::Lanes(d) == 0); + for (size_t dim = 0; dim < half_dim_qkv; dim += hn::Lanes(d)) { + // Compute thetas + V pos_vec = hn::Set(d, pos); + V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim); + V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec); + + // Compute rotations. + V cos_theta_vec; + V sin_theta_vec; + hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec); + + // Scale input with rotations and multiply with constant. + V mul_vec = hn::Set(d, mul); + V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim)); + V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv)); + + V xout_0_vec = hn::MulSub(x0_vec, cos_theta_vec, + hn::Mul(x1_vec, sin_theta_vec)); + V xout_1_vec = hn::MulAdd(x0_vec, sin_theta_vec, + hn::Mul(x1_vec, cos_theta_vec)); + + // Store + hn::StoreU(xout_0_vec, d, x_out + dim); + hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv); + } +} + static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 0a66f13..a8d2fee 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -38,7 +38,12 @@ #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" // After highway.h +#include "gemma/activations.h" +#include "gemma/common.h" +#include "gemma/configs.h" #include "ops/ops-inl.h" +#include "util/allocator.h" +#include "hwy/tests/hwy_gtest.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -361,6 +366,55 @@ void TestSigmoid() { } } +void TestRopeAndMulBy() { + using Config = ConfigGemma2_9B; + int dim_qkv = Config::kQKVDim; + RowVectorBatch x(1, dim_qkv); + + std::mt19937 gen; + gen.seed(0x12345678); + std::normal_distribution r{0.0, 5.0}; + auto random_float = [&r, &gen] { return r(gen); }; + + for (int i = 0; i < dim_qkv; ++i) { + x.All()[i] = random_float(); + } + + const float qmul = ChooseQueryScale(); + const float kmul = 1.0; + + std::vector qexpected(dim_qkv); + std::vector qactual(dim_qkv); + std::vector kexpected(dim_qkv); + std::vector kactual(dim_qkv); + RowVectorBatch inv_timescale = + gcpp::Activations::CreateInvTimescale(); + // Assert VectorizedRope computation is same as regular rope at different pos. + for (int pos = 1; pos < 500; pos++) { + // Rope'd Q embeddings + RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + qexpected.data()); + VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + qactual.data()); + + for (int i = 0; i < dim_qkv; ++i) { + EXPECT_NEAR(qactual[i], qexpected[i], 1e-4) + << "qIndex:" << i << "qInput:" << qactual[i]; + } + + // Rope'd K embeddings + RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + kexpected.data()); + VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + kactual.data()); + + for (int i = 0; i < dim_qkv; ++i) { + EXPECT_NEAR(kactual[i], kexpected[i], 1e-4) + << "kIndex:" << i << "kInput:" << kactual[i]; + } + } +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -377,6 +431,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); +HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy); HWY_AFTER_TEST(); } // namespace gcpp