mirror of https://github.com/google/gemma.cpp.git

Vectorize Rope for qkv dim not evenly divisible by number of lanes.

PiperOrigin-RevId: 665776602

parent 18e6012872
commit 48d0801fb0
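For context, the commit replaces a hard requirement that half_dim_qkv be a multiple of the vector width (the HWY_DASSERT removed below) with a two-phase loop: full-width vectors over the lane-aligned prefix, then one masked pass over the tail via Highway's LoadN/StoreN. A minimal sketch of that pattern, independent of RoPE (the function MulBy and its signature are made up for illustration; only the RoundDownTo/LoadN/StoreN usage mirrors the commit):

#include <cstddef>

#include "hwy/base.h"     // hwy::RoundDownTo
#include "hwy/highway.h"  // hn::LoadN / hn::StoreN

namespace hn = hwy::HWY_NAMESPACE;

// Hypothetical helper, not from the commit: scale `num` floats by `mul`.
// Full vectors cover the lane-aligned prefix; LoadN/StoreN touch only the
// `remaining` tail elements, so `num` need not be a multiple of Lanes(d).
void MulBy(float mul, const float* HWY_RESTRICT in, size_t num,
           float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto mul_vec = hn::Set(d, mul);
  const size_t aligned = hwy::RoundDownTo(num, N);
  for (size_t i = 0; i < aligned; i += N) {
    hn::StoreU(hn::Mul(mul_vec, hn::LoadU(d, in + i)), d, out + i);
  }
  const size_t remaining = num - aligned;
  if (remaining != 0) {
    const auto v = hn::Mul(mul_vec, hn::LoadN(d, in + aligned, remaining));
    hn::StoreN(v, d, out + aligned, remaining);
  }
}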
@@ -15,6 +15,7 @@
 #include <stddef.h>
 
+#include <algorithm>
 #include <random>
 #include <vector>
 
@@ -81,15 +82,16 @@ TEST(OptimizeTest, GradientDescent) {
     return reply;
   };
 
+  // Sanity check of reply tokens.
+  // 1) Its length should be greater than the prompt.
+  // 2) The prompt should be a prefix of the reply.
   auto verify = [&](const Prompt& prompt) {
     auto context = prompt.context();
     std::vector<int> reply = generate(context);
     bool ok = true;
-    for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) {
-      if (i >= reply.size() || reply[i] != prompt.tokens[i]) {
-        ok = false;
-      }
-    }
+    ok &= (reply.size() > context.size());
+    ok &= std::equal(prompt.tokens.begin(), prompt.tokens.end(),
+                     reply.begin(), reply.begin() + prompt.tokens.size());
     return ok;
   };
 
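The prefix check introduced above can be exercised standalone. A minimal sketch (IsStrictPrefix is a hypothetical helper, not part of the commit), which makes the length guard explicit before the iterator arithmetic that std::equal relies on:

#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper, not from the commit: true iff `prefix` is a strict
// prefix of `seq`. The explicit size guard makes the iterator arithmetic
// `seq.begin() + prefix.size()` safe before std::equal runs.
static bool IsStrictPrefix(const std::vector<int>& prefix,
                           const std::vector<int>& seq) {
  if (seq.size() <= prefix.size()) return false;
  return std::equal(prefix.begin(), prefix.end(), seq.begin(),
                    seq.begin() + prefix.size());
}

int main() {
  assert(IsStrictPrefix({1, 2}, {1, 2, 3}));
  assert(!IsStrictPrefix({1, 2}, {1, 3, 4}));
  assert(!IsStrictPrefix({1, 2}, {1, 2}));  // equal, not a strict prefix
  return 0;
}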
@@ -432,8 +432,9 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
   using V = hn::Vec<D>;
   const D d;
 
-  HWY_DASSERT(half_dim_qkv % hn::Lanes(d) == 0);
-  for (size_t dim = 0; dim < half_dim_qkv; dim += hn::Lanes(d)) {
+  // Vectorize computation for half_dim_qkv - (half_dim_qkv % Lanes)
+  size_t vectorizable_dims = hwy::RoundDownTo(half_dim_qkv, hn::Lanes(d));
+  for (size_t dim = 0; dim < vectorizable_dims; dim += hn::Lanes(d)) {
     // Compute thetas
     V pos_vec = hn::Set(d, pos);
     V inv_time_scale_vec = hn::LoadU(d, inv_timescale + dim);
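To make the new loop bound concrete (numbers purely illustrative): with half_dim_qkv = 100 and a 16-lane vector, hwy::RoundDownTo(100, 16) gives vectorizable_dims = 96, so the full-width loop covers dims 0..95 and the remainder loop added in the next hunk handles the last 4 elements with LoadN/StoreN.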
@@ -458,6 +459,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
     hn::StoreU(xout_0_vec, d, x_out + dim);
     hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
   }
+
+  // Vectorize computation for remaining dims.
+  size_t remaining_dims = half_dim_qkv - vectorizable_dims;
+  for (size_t dim = vectorizable_dims; dim < half_dim_qkv;
+       dim += hn::Lanes(d)) {
+    // Compute thetas
+    V pos_vec = hn::Set(d, pos);
+    V inv_time_scale_vec = hn::LoadN(d, inv_timescale + dim, remaining_dims);
+    V theta_vec = hn::Mul(pos_vec, inv_time_scale_vec);
+
+    // Compute rotations.
+    V cos_theta_vec;
+    V sin_theta_vec;
+    hn::SinCos(d, theta_vec, sin_theta_vec, cos_theta_vec);
+
+    // Scale input with rotations and multiply with constant.
+    V mul_vec = hn::Set(d, mul);
+    V x0_vec = hn::Mul(mul_vec, hn::LoadN(d, x + dim, remaining_dims));
+    V x1_vec =
+        hn::Mul(mul_vec, hn::LoadN(d, x + dim + half_dim_qkv, remaining_dims));
+
+    V xout_0_vec =
+        hn::MulSub(x0_vec, cos_theta_vec, hn::Mul(x1_vec, sin_theta_vec));
+    V xout_1_vec =
+        hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
+
+    // Store
+    hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
+    hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
+  }
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
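Both loops above implement the same per-pair rotation; only the load/store width differs. A scalar reference sketch for cross-checking the arithmetic (the enclosing function's full parameter list is not visible in this diff, so this signature is an assumption):

#include <cmath>
#include <cstddef>

// Assumed signature, mirroring the arguments visible above. For each pair
// (x[dim], x[dim + half_dim_qkv]), rotate by theta = pos * inv_timescale[dim]
// after scaling by `mul`:
//   x_out[dim]                = x0 * cos(theta) - x1 * sin(theta)
//   x_out[dim + half_dim_qkv] = x0 * sin(theta) + x1 * cos(theta)
static void ScalarRopeAndMulBy(float mul, const float* x, size_t pos,
                               const float* inv_timescale,
                               size_t half_dim_qkv, float* x_out) {
  for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
    const float theta = static_cast<float>(pos) * inv_timescale[dim];
    const float x0 = mul * x[dim];
    const float x1 = mul * x[dim + half_dim_qkv];
    x_out[dim] = x0 * std::cos(theta) - x1 * std::sin(theta);
    x_out[dim + half_dim_qkv] = x0 * std::sin(theta) + x1 * std::cos(theta);
  }
}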