Always apply PositionalEncodingQK in-place.

PiperOrigin-RevId: 718851803
Daniel Keysers, 2025-01-23 07:08:50 -08:00 (committed by Copybara-Service)
parent ce807a31a1
commit e997468496
3 changed files with 34 additions and 38 deletions
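
For context: the change converts PositionalEncodingQK, RopeAndMulBy, and ScalarRopeAndMulBy from writing into a separate x_out buffer to rotating their input in place. Below is a minimal scalar sketch of the new in-place contract (function name hypothetical; the real kernel in the second file uses Highway SIMD intrinsics):

#include <cmath>
#include <cstddef>

// Rotate each pair (x[i], x[i + d/2]) by pos * inv_timescale[i], then scale
// by mul. Overwrites x; callers that still need the input must copy it first.
void RopeAndMulByScalarSketch(float mul, float* x, size_t dim_qkv,
                              const float* inv_timescale, int pos) {
  const size_t half = dim_qkv / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    const float theta = pos * inv_timescale[dim];
    const float cos_val = cosf(theta);
    const float sin_val = sinf(theta);
    const float x0 = x[dim];
    const float x1 = x[dim + half];
    x[dim] = mul * (x0 * cos_val - x1 * sin_val);
    x[dim + half] = mul * (x0 * sin_val + x1 * cos_val);
  }
}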

View File

@@ -213,19 +213,18 @@ class GemmaAttention {
   }

   template <typename U>
-  HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
-                                       const float mul, U* qk_out) {
+  HWY_INLINE void PositionalEncodingQK(U* qk, size_t pos, size_t layer,
+                                       const float mul) {
     // qk is either q or k, so qkv_dim is the length we operate on.
     const size_t qkv_dim = layer_config_.qkv_dim;
     const float* inv_timescale = activations_.inv_timescale.Const();
     // PostQKType::Rope
     (void)layer;
     if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
-      Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
-      MulByConst(mul, qk_out, qkv_dim);
+      Rope(qk, qkv_dim / 2, inv_timescale, pos);
+      if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
     } else {
-      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
+      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
     }
   }
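
Two consequences of the in-place form are visible in this hunk. First, the HalfRope branch no longer needs a CopyBytes staging step. Second, the new `if (mul != 1.0f)` guard lets the K path (which passes mul = 1.0f below) skip an entire MulByConst pass. The trade-off is that the input is consumed; a caller that still needs the un-rotated vector copies it first, e.g. (illustrative only; kQKVDim is a made-up constant):

float roped[kQKVDim];
hwy::CopyBytes(qk, roped, kQKVDim * sizeof(float));     // preserve the input
PositionalEncodingQK(roped, pos, layer, /*mul=*/1.0f);  // rope the copy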
@@ -315,19 +314,16 @@ class GemmaAttention {
                                       head * qkv_dim * 2;
           KVCache& kv_cache = kv_caches_[query_idx];
           float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-          const float* HWY_RESTRICT mha_kv =
-              activations_.q.Batch(interleaved_idx) + head * q_stride_ +
-              qkv_dim;
-          // Copy from `q` if MHA, or apply in-place.
-          PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
-                               kv);
-          // If MHA, also copy V into KVCache.
+          // If MHA, copy computed K and V into KVCache.
           if (is_mha_) {
-            hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
-                           qkv_dim * sizeof(*kv));
+            const float* HWY_RESTRICT mha_kv =
+                activations_.q.Batch(interleaved_idx) + head * q_stride_ +
+                qkv_dim;
+            hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
           }
+          // Apply further processing to K.
+          PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
         });
   }
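
The MHA path is also restructured: previously K reached the cache through PositionalEncodingQK's copy-out parameter and only V needed an explicit CopyBytes; now both halves are copied raw in a single CopyBytes and K is roped afterwards, in place inside the cache. The fused copy is valid because K and V are adjacent qkv_dim-float slices on both sides (sketch, with names from the diff):

// Layout assumed by the single copy:
//   mha_kv: [ K (qkv_dim floats) | V (qkv_dim floats) ]
//   kv:     [ K (qkv_dim floats) | V (qkv_dim floats) ]
hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));  // move K and V at once
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);    // rope K in the cache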
@@ -414,7 +410,7 @@ class GemmaAttention {
       // Apply rope and scaling to Q.
       const size_t pos = queries_pos_[query_idx] + batch_idx;
-      PositionalEncodingQK(q, pos, layer_, query_scale, q);
+      PositionalEncodingQK(q, pos, layer_, query_scale);
       const size_t start_pos = StartPos(pos, layer_);
       size_t last_pos = pos;

View File

@@ -333,9 +333,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos,
-    float* HWY_RESTRICT x_out) {
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, int pos) {
   PROFILER_FUNC;
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
@@ -369,8 +368,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
                    hn::Mul(x1_vec, cos_theta_vec));
     // Store
-    hn::StoreU(xout_0_vec, d, x_out + dim);
-    hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
+    hn::StoreU(xout_0_vec, d, x + dim);
+    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
   }

   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -399,8 +398,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
         hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
     // Store
-    hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
-    hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
+    hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
   }
 }
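
Dropping x_out also removes an aliasing wrinkle: the Q path already called the old signature with the same pointer for input and output (PositionalEncodingQK(q, ..., q) above), so two HWY_RESTRICT parameters aliased, which the restrict contract forbids. Rotating in place is safe because each iteration loads both halves of a pair before storing either, as in the scalar form:

const float x0 = x[dim];                 // both reads happen first
const float x1 = x[dim + half_dim_qkv];
x[dim] = mul * (x0 * cos_val - x1 * sin_val);            // then both writes
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);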

View File

@@ -370,9 +370,8 @@ void TestSigmoid() {
 }

 static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
-    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos,
-    float* HWY_RESTRICT x_out) {
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -381,8 +380,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
     const float sin_val = sinf(theta);
     const float x0 = x[dim];
     const float x1 = x[dim + half_dim_qkv];
-    x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
-    x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
+    x[dim] = mul * (x0 * cos_val - x1 * sin_val);
+    x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
   }
 }
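
For reference, the rotation this scalar code implements, with d = dim_qkv and theta_i = pos * inv_timescale[i]:

$$
\begin{aligned}
x'_i &= \text{mul}\,(x_i \cos\theta_i - x_{i+d/2} \sin\theta_i) \\
x'_{i+d/2} &= \text{mul}\,(x_i \sin\theta_i + x_{i+d/2} \cos\theta_i)
\end{aligned}
$$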
@@ -413,10 +412,11 @@ void TestRopeAndMulBy() {
   // Assert VectorizedRope computation is same as regular rope at different pos.
   for (int pos = 1; pos < 500; pos++) {
     // Rope'd Q embeddings
-    ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                       qexpected.data());
-    RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                 qactual.data());
+    hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv * sizeof(float));
+    hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv * sizeof(float));
+    ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(),
+                       pos);
+    RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos);
     for (int i = 0; i < dim_qkv; ++i) {
       EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
@@ -424,10 +424,11 @@ void TestRopeAndMulBy() {
     }
     // Rope'd K embeddings
-    ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                       kexpected.data());
-    RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                 kactual.data());
+    hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv * sizeof(float));
+    hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv * sizeof(float));
+    ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(),
+                       pos);
+    RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos);
     for (int i = 0; i < dim_qkv; ++i) {
       EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)