mirror of https://github.com/google/gemma.cpp.git
Vectorize Rope for qkv dim not evenly divisible by number of lanes.
PiperOrigin-RevId: 665776602
This commit is contained in:
parent
18e6012872
commit
48d0801fb0
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -81,15 +82,16 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
return reply;
|
||||
};
|
||||
|
||||
// Sanity check of reply tokens.
|
||||
// 1) Its length should be greater than the prompt.
|
||||
// 2) The prompt should be a prefix of the reply.
|
||||
auto verify = [&](const Prompt& prompt) {
|
||||
auto context = prompt.context();
|
||||
std::vector<int> reply = generate(context);
|
||||
bool ok = true;
|
||||
for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
|
||||
if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
|
||||
ok = false;
|
||||
}
|
||||
}
|
||||
ok &= (reply.size() > context.size());
|
||||
ok &= std::equal(prompt.tokens.begin(), prompt.tokens.end(),
|
||||
reply.begin(), reply.begin() + prompt.tokens.size());
|
||||
return ok;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -432,8 +432,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
|
|||
using V = hn::Vec<D>;
|
||||
const D d;
|
||||
|
||||
HWY_DASSERT(half_dim_qkv % hn::Lanes(d) == 0);
|
||||
for (size_t dim = 0; dim < half_dim_qkv; dim += hn::Lanes(d)) {
|
||||
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
|
||||
size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
|
||||
for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) {
|
||||
// Compute thetas
|
||||
V pos_vec = hn::Set(d, pos);
|
||||
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
|
||||
|
|
@ -458,6 +459,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
|
|||
hn::StoreU(xout_0_vec, d, x_out + dim);
|
||||
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
|
||||
}
|
||||
|
||||
// Vectorize computation for remaining dims.
|
||||
size_t remaining_dims = half_dim_qkv - vectorizable_dims;
|
||||
for (size_t dim = vectorizable_dims; dim < half_dim_qkv;
|
||||
dim += hn::Lanes(d)) {
|
||||
// Compute thetas
|
||||
V pos_vec = hn::Set(d, pos);
|
||||
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
|
||||
V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
|
||||
|
||||
// Compute rotations.
|
||||
V cos_theta_vec;
|
||||
V sin_theta_vec;
|
||||
hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
|
||||
|
||||
// Scale input with rotations and multiply with constant.
|
||||
V mul_vec = hn::Set(d, mul);
|
||||
V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
|
||||
V x1_vec =
|
||||
hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
|
||||
|
||||
V xout_0_vec =
|
||||
hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
|
||||
V xout_1_vec =
|
||||
hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
|
||||
|
||||
// Store
|
||||
hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
|
||||
hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
|
||||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||
|
|
|
|||
Loading…
Reference in New Issue