mirror of https://github.com/google/gemma.cpp.git
Apply PositionalEncodingQK always in-place.
PiperOrigin-RevId: 718851803
This commit is contained in:
parent
ce807a31a1
commit
e997468496
|
|
@ -213,19 +213,18 @@ class GemmaAttention {
|
|||
}
|
||||
|
||||
template <typename U>
|
||||
HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
|
||||
const float mul, U* qk_out) {
|
||||
HWY_INLINE void PositionalEncodingQK(U* qk, size_t pos, size_t layer,
|
||||
const float mul) {
|
||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const float* inv_timescale = activations_.inv_timescale.Const();
|
||||
// PostQKType::Rope
|
||||
(void)layer;
|
||||
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
|
||||
hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
|
||||
Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
|
||||
MulByConst(mul, qk_out, qkv_dim);
|
||||
Rope(qk, qkv_dim / 2, inv_timescale, pos);
|
||||
if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
|
||||
} else {
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
|
||||
RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -315,19 +314,16 @@ class GemmaAttention {
|
|||
head * qkv_dim * 2;
|
||||
KVCache& kv_cache = kv_caches_[query_idx];
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
const float* HWY_RESTRICT mha_kv =
|
||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||
qkv_dim;
|
||||
|
||||
// Copy from `q` if MHA, or apply in-place.
|
||||
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
|
||||
kv);
|
||||
|
||||
// If MHA, also copy V into KVCache.
|
||||
// If MHA, copy computed K and V into KVCache.
|
||||
if (is_mha_) {
|
||||
hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
|
||||
qkv_dim * sizeof(*kv));
|
||||
const float* HWY_RESTRICT mha_kv =
|
||||
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
|
||||
qkv_dim;
|
||||
hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
|
||||
}
|
||||
|
||||
// Apply further processing to K.
|
||||
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -414,7 +410,7 @@ class GemmaAttention {
|
|||
|
||||
// Apply rope and scaling to Q.
|
||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
PositionalEncodingQK(q, pos, layer_, query_scale, q);
|
||||
PositionalEncodingQK(q, pos, layer_, query_scale);
|
||||
|
||||
const size_t start_pos = StartPos(pos, layer_);
|
||||
size_t last_pos = pos;
|
||||
|
|
|
|||
|
|
@ -333,9 +333,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
|||
|
||||
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos,
|
||||
float* HWY_RESTRICT x_out) {
|
||||
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||
PROFILER_FUNC;
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
|
|
@ -369,8 +368,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
hn::Mul(x1_vec, cos_theta_vec));
|
||||
|
||||
// Store
|
||||
hn::StoreU(xout_0_vec, d, x_out + dim);
|
||||
hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
|
||||
hn::StoreU(xout_0_vec, d, x + dim);
|
||||
hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
|
||||
}
|
||||
|
||||
// Vectorize computation for remaining dims - same as above, but with LoadN.
|
||||
|
|
@ -399,8 +398,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
|||
hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
|
||||
|
||||
// Store
|
||||
hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
|
||||
hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
|
||||
hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
|
||||
hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -370,9 +370,8 @@ void TestSigmoid() {
|
|||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
||||
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos,
|
||||
float* HWY_RESTRICT x_out) {
|
||||
const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, int pos) {
|
||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||
const size_t half_dim_qkv = dim_qkv / 2;
|
||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||
|
|
@ -381,8 +380,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
const float sin_val = sinf(theta);
|
||||
const float x0 = x[dim];
|
||||
const float x1 = x[dim + half_dim_qkv];
|
||||
x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
||||
x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
||||
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
||||
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -413,10 +412,11 @@ void TestRopeAndMulBy() {
|
|||
// Assert VectorizedRope computation is same as regular rope at different pos.
|
||||
for (int pos = 1; pos < 500; pos++) {
|
||||
// Rope'd Q embeddings
|
||||
ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||
qexpected.data());
|
||||
RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||
qactual.data());
|
||||
hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv);
|
||||
hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv);
|
||||
ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(),
|
||||
pos);
|
||||
RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos);
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
|
||||
|
|
@ -424,10 +424,11 @@ void TestRopeAndMulBy() {
|
|||
}
|
||||
|
||||
// Rope'd K embeddings
|
||||
ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||
kexpected.data());
|
||||
RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||
kactual.data());
|
||||
hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv);
|
||||
hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv);
|
||||
ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(),
|
||||
pos);
|
||||
RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos);
|
||||
|
||||
for (int i = 0; i < dim_qkv; ++i) {
|
||||
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
|
||||
|
|
|
|||
Loading…
Reference in New Issue