Vectorize Rope for qkv dim not evenly divisible by number of lanes.

PiperOrigin-RevId: 665776602
This commit is contained in:
Apoorv Reddy 2024-08-21 02:21:47 -07:00 committed by Copybara-Service
parent 18e6012872
commit 48d0801fb0
2 changed files with 40 additions and 7 deletions

View File

@ -15,6 +15,7 @@
#include <stddef.h>
#include <algorithm>
#include <random>
#include <vector>
@ -81,15 +82,16 @@ TEST(OptimizeTest, GradientDescent) {
return reply;
};
// Sanity check of reply tokens.
// 1) Its length should be greater than the prompt.
// 2) The prompt should be a prefix of the reply.
auto verify = [&](const Prompt& prompt) {
auto context = prompt.context();
std::vector<int> reply = generate(context);
bool ok = true;
for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
ok = false;
}
}
ok &= (reply.size() > context.size());
ok &= std::equal(prompt.tokens.begin(), prompt.tokens.end(),
reply.begin(), reply.begin() + prompt.tokens.size());
return ok;
};

View File

@ -432,8 +432,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
using V = hn::Vec<D>;
const D d;
HWY_DASSERT(half_dim_qkv % hn::Lanes(d) == 0);
for (size_t dim = 0; dim < half_dim_qkv; dim += hn::Lanes(d)) {
// Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
@ -458,6 +459,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
hn::StoreU(xout_0_vec, d, x_out + dim);
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
}
// Vectorize computation for remaining dims.
size_t remaining_dims = half_dim_qkv - vectorizable_dims;
for (size_t dim = vectorizable_dims; dim < half_dim_qkv;
dim += hn::Lanes(d)) {
// Compute thetas
V pos_vec = hn::Set(d, pos);
V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
// Compute rotations.
V cos_theta_vec;
V sin_theta_vec;
hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
// Scale input with rotations and multiply with constant.
V mul_vec = hn::Set(d, mul);
V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
V x1_vec =
hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
V xout_0_vec =
hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
V xout_1_vec =
hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
// Store
hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(