mirror of https://github.com/google/gemma.cpp.git
VectorizedRopeAndMulBy.
~8x reduction in time spent in Rope (tested on a few prompts), giving a ~3.8% prefill latency improvement and a ~2.6% decode latency improvement. PiperOrigin-RevId: 664650108
This commit is contained in:
parent
773333e5be
commit
c6eb3b6f0d
|
|
@ -72,11 +72,14 @@ cc_test(
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":allocator",
|
||||||
|
":common",
|
||||||
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark", #buildcleaner: keep
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -228,7 +228,7 @@ class GemmaAttention {
|
||||||
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
|
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
|
||||||
MulByConst(mul, qk_out, kQKVDim);
|
MulByConst(mul, qk_out, kQKVDim);
|
||||||
} else {
|
} else {
|
||||||
RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
|
VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -420,6 +420,46 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SIMD version of RopeAndMulBy: rotates pairs of elements of `x` by a
// position-dependent angle and scales them by `mul`.
//
// With half = dim_qkv / 2 and theta_i = pos * inv_timescale[i], for every
// i in [0, half):
//   x_out[i]        = (mul * x[i]) * cos(theta_i) - (mul * x[i + half]) * sin(theta_i)
//   x_out[i + half] = (mul * x[i]) * sin(theta_i) + (mul * x[i + half]) * cos(theta_i)
//
// `mul`:           constant multiplier applied to every input element.
// `x`:             input, `dim_qkv` floats.
// `dim_qkv`:       number of elements in `x` / `x_out`; must be even.
// `inv_timescale`: `dim_qkv / 2` floats, one per rotated pair.
// `pos`:           position used to derive the rotation angles.
// `x_out`:         output, `dim_qkv` floats.
//
// `dim_qkv / 2` does not have to be a multiple of the vector length: leftover
// elements are handled with partial loads/stores so that no memory outside the
// given buffers is touched.
static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
    const float* HWY_RESTRICT inv_timescale, int pos,
    float* HWY_RESTRICT x_out) {
  PROFILER_FUNC;
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half_dim_qkv = dim_qkv / 2;

  using D = hn::ScalableTag<float>;
  using V = hn::Vec<D>;
  const D d;
  const size_t N = hn::Lanes(d);

  // Loop-invariant broadcasts.
  const V pos_vec = hn::Set(d, static_cast<float>(pos));
  const V mul_vec = hn::Set(d, mul);

  // Whole vectors: the largest multiple of N that fits in half_dim_qkv.
  const size_t vectorizable_dims = half_dim_qkv - (half_dim_qkv % N);
  size_t dim = 0;
  for (; dim < vectorizable_dims; dim += N) {
    // Compute thetas.
    const V inv_timescale_vec = hn::LoadU(d, inv_timescale + dim);
    const V theta_vec = hn::Mul(pos_vec, inv_timescale_vec);

    // Compute rotations.
    V cos_theta_vec;
    V sin_theta_vec;
    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);

    // Scale input by the constant, then rotate.
    const V x0_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim));
    const V x1_vec = hn::Mul(mul_vec, hn::LoadU(d, x + dim + half_dim_qkv));

    const V xout_0_vec =
        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
    const V xout_1_vec =
        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));

    // Store.
    hn::StoreU(xout_0_vec, d, x_out + dim);
    hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
  }

  // Remainder (fewer than N elements): same computation, but only the first
  // `remaining_dims` lanes are loaded and stored.
  const size_t remaining_dims = half_dim_qkv - dim;
  HWY_DASSERT(remaining_dims < N);
  if (remaining_dims != 0) {
    const V inv_timescale_vec =
        hn::LoadN(d, inv_timescale + dim, remaining_dims);
    const V theta_vec = hn::Mul(pos_vec, inv_timescale_vec);

    V cos_theta_vec;
    V sin_theta_vec;
    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);

    const V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
    const V x1_vec = hn::Mul(
        mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));

    const V xout_0_vec =
        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
    const V xout_1_vec =
        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));

    hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
    hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
  }
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,12 @@
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
|
#include "gemma/activations.h"
|
||||||
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
#include "ops/ops-inl.h"
|
#include "ops/ops-inl.h"
|
||||||
|
#include "util/allocator.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -361,6 +366,55 @@ void TestSigmoid() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestRopeAndMulBy() {
|
||||||
|
using Config = ConfigGemma2_9B<float>;
|
||||||
|
int dim_qkv = Config::kQKVDim;
|
||||||
|
RowVectorBatch<float> x(1, dim_qkv);
|
||||||
|
|
||||||
|
std::mt19937 gen;
|
||||||
|
gen.seed(0x12345678);
|
||||||
|
std::normal_distribution<float> r{0.0, 5.0};
|
||||||
|
auto random_float = [&r, &gen] { return r(gen); };
|
||||||
|
|
||||||
|
for (int i = 0; i < dim_qkv; ++i) {
|
||||||
|
x.All()[i] = random_float();
|
||||||
|
}
|
||||||
|
|
||||||
|
const float qmul = ChooseQueryScale<Config>();
|
||||||
|
const float kmul = 1.0;
|
||||||
|
|
||||||
|
std::vector<float> qexpected(dim_qkv);
|
||||||
|
std::vector<float> qactual(dim_qkv);
|
||||||
|
std::vector<float> kexpected(dim_qkv);
|
||||||
|
std::vector<float> kactual(dim_qkv);
|
||||||
|
RowVectorBatch<float> inv_timescale =
|
||||||
|
gcpp::Activations::CreateInvTimescale<Config>();
|
||||||
|
// Assert VectorizedRope computation is same as regular rope at different pos.
|
||||||
|
for (int pos = 1; pos < 500; pos++) {
|
||||||
|
// Rope'd Q embeddings
|
||||||
|
RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
qexpected.data());
|
||||||
|
VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
qactual.data());
|
||||||
|
|
||||||
|
for (int i = 0; i < dim_qkv; ++i) {
|
||||||
|
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
|
||||||
|
<< "qIndex:" << i << "qInput:" << qactual[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rope'd K embeddings
|
||||||
|
RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
kexpected.data());
|
||||||
|
VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
kactual.data());
|
||||||
|
|
||||||
|
for (int i = 0; i < dim_qkv; ++i) {
|
||||||
|
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
|
||||||
|
<< "kIndex:" << i << "kInput:" << kactual[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -377,6 +431,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
||||||
HWY_AFTER_TEST();
|
HWY_AFTER_TEST();
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue