VectorizedRopeAndMulBy.

~8x reduction in Rope (tested on a few prompts).
~3.8% prefill latency improvement.
~2.6% decode latency improvement.

PiperOrigin-RevId: 664650108
This commit is contained in:
Apoorv Reddy 2024-08-18 23:16:29 -07:00 committed by Copybara-Service
parent 773333e5be
commit c6eb3b6f0d
4 changed files with 100 additions and 2 deletions

View File

@ -72,11 +72,14 @@ cc_test(
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":allocator",
":common",
":gemma_lib",
":ops",
"@googletest//:gtest_main", # buildcleaner: keep
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:nanobenchmark", #buildcleaner: keep
],
)

View File

@ -228,7 +228,7 @@ class GemmaAttention {
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
MulByConst(mul, qk_out, kQKVDim);
} else {
RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
}
}

View File

@ -420,6 +420,46 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
}
}
// SIMD version of the combined "scale by constant + rotary embedding" step.
//
// For every i in [0, dim_qkv / 2) this computes, with
//   theta = pos * inv_timescale[i],
//   x0    = mul * x[i],
//   x1    = mul * x[i + dim_qkv / 2]:
//   x_out[i]               = x0 * cos(theta) - x1 * sin(theta)
//   x_out[i + dim_qkv / 2] = x0 * sin(theta) + x1 * cos(theta)
//
// Arguments:
//   mul:           constant factor applied to both halves of `x`.
//   x:             input, `dim_qkv` floats; read with unaligned loads.
//   dim_qkv:       number of elements in `x` / `x_out`; must be even.
//   inv_timescale: the first `dim_qkv / 2` entries are read.
//   pos:           position; converted to float before the multiply.
//   x_out:         output, `dim_qkv` floats; written with unaligned stores.
//
// NOTE: `dim_qkv / 2` must be a multiple of the vector length. This is only
// verified by HWY_DASSERT (debug builds); there is no remainder handling, so
// the final iteration would access past the end of the buffers otherwise.
static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
    const float* HWY_RESTRICT inv_timescale, int pos,
    float* HWY_RESTRICT x_out) {
  PROFILER_FUNC;
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half_dim_qkv = dim_qkv / 2;

  using D = hn::ScalableTag<float>;
  using V = hn::Vec<D>;
  const D d;
  const size_t lanes = hn::Lanes(d);
  HWY_DASSERT(half_dim_qkv % lanes == 0);

  // Loop-invariant broadcasts: built once instead of once per iteration.
  const V pos_vec = hn::Set(d, static_cast<float>(pos));
  const V mul_vec = hn::Set(d, mul);

  for (size_t dim = 0; dim < half_dim_qkv; dim += lanes) {
    // Rotation angles for this group of lanes.
    const V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
    const V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);

    // Sine and cosine of the angles; both are outputs of SinCos.
    V cos_theta_vec;
    V sin_theta_vec;
    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);

    // Scale the two halves of the input before rotating. The multiplication
    // order matches the original so results stay bit-identical.
    const V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
    const V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));

    // MulSub(a, b, c) = a * b - c; MulAdd(a, b, c) = a * b + c.
    const V xout_0_vec =
        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
    const V xout_1_vec =
        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));

    hn::StoreU(xout_0_vec, d, x_out + dim);
    hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
  }
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
namespace hn = hwy::HWY_NAMESPACE;

View File

@ -38,7 +38,12 @@
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "ops/ops-inl.h"
#include "util/allocator.h"
#include "hwy/tests/hwy_gtest.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -361,6 +366,55 @@ void TestSigmoid() {
}
}
// Verifies that VectorizedRopeAndMulBy produces the same output (within 1e-4
// per element) as RopeAndMulBy. The same pseudo-random input row is rotated
// for positions 1..499, once with the query scale and once with a scale of 1
// (the key path).
void TestRopeAndMulBy() {
  using Config = ConfigGemma2_9B<float>;
  int dim_qkv = Config::kQKVDim;
  RowVectorBatch<float> x(1, dim_qkv);

  // Fixed seed so that any failure is reproducible.
  std::mt19937 gen;
  gen.seed(0x12345678);
  std::normal_distribution<float> r{0.0, 5.0};
  auto random_float = [&r, &gen] { return r(gen); };

  for (int i = 0; i < dim_qkv; ++i) {
    x.All()[i] = random_float();
  }

  const float qmul = ChooseQueryScale<Config>();
  const float kmul = 1.0;

  std::vector<float> qexpected(dim_qkv);
  std::vector<float> qactual(dim_qkv);
  std::vector<float> kexpected(dim_qkv);
  std::vector<float> kactual(dim_qkv);

  RowVectorBatch<float> inv_timescale =
      gcpp::Activations::CreateInvTimescale<Config>();

  // Assert VectorizedRope computation is same as regular rope at different pos.
  for (int pos = 1; pos < 500; pos++) {
    // Rope'd Q embeddings
    RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
                 qexpected.data());
    VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
                           qactual.data());

    for (int i = 0; i < dim_qkv; ++i) {
      // EXPECT_NEAR already prints actual/expected; add the position, the
      // element index and the input value that produced the mismatch.
      EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
          << "pos:" << pos << " qIndex:" << i << " qInput:" << x.All()[i];
    }

    // Rope'd K embeddings
    RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
                 kexpected.data());
    VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
                           kactual.data());

    for (int i = 0; i < dim_qkv; ++i) {
      EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
          << "pos:" << pos << " kIndex:" << i << " kInput:" << x.All()[i];
    }
  }
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
@ -377,6 +431,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_AFTER_TEST();
} // namespace gcpp