diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index d7b3a79..1cfe7a5 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -213,19 +213,18 @@ class GemmaAttention {
   }
 
   template <typename U>
-  HWY_INLINE void PositionalEncodingQK(const U* qk, size_t pos, size_t layer,
-                                       const float mul, U* qk_out) {
+  HWY_INLINE void PositionalEncodingQK(U* qk, size_t pos, size_t layer,
+                                       const float mul) {
     // qk is either q or k, so qkv_dim is the length we operate on.
     const size_t qkv_dim = layer_config_.qkv_dim;
     const float* inv_timescale = activations_.inv_timescale.Const();
     // PostQKType::Rope
     (void)layer;
     if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
-      hwy::CopyBytes(qk, qk_out, qkv_dim * sizeof(*qk));
-      Rope(qk_out, qkv_dim / 2, inv_timescale, pos);
-      MulByConst(mul, qk_out, qkv_dim);
+      Rope(qk, qkv_dim / 2, inv_timescale, pos);
+      if (mul != 1.0f) MulByConst(mul, qk, qkv_dim);
     } else {
-      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos, qk_out);
+      RopeAndMulBy(mul, qk, qkv_dim, inv_timescale, pos);
     }
   }
 
@@ -315,19 +314,16 @@ class GemmaAttention {
                                        head * qkv_dim * 2;
               KVCache& kv_cache = kv_caches_[query_idx];
               float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-              const float* HWY_RESTRICT mha_kv =
-                  activations_.q.Batch(interleaved_idx) + head * q_stride_ +
-                  qkv_dim;
-
-              // Copy from `q` if MHA, or apply in-place.
-              PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f,
-                                   kv);
-
-              // If MHA, also copy V into KVCache.
+              // If MHA, copy computed K and V into KVCache.
               if (is_mha_) {
-                hwy::CopyBytes(mha_kv + qkv_dim, kv + qkv_dim,
-                               qkv_dim * sizeof(*kv));
+                const float* HWY_RESTRICT mha_kv =
+                    activations_.q.Batch(interleaved_idx) + head * q_stride_ +
+                    qkv_dim;
+                hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
               }
+
+              // Apply further processing to K.
+              PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
             });
   }
 
@@ -414,7 +410,7 @@ class GemmaAttention {
 
         // Apply rope and scaling to Q.
         const size_t pos = queries_pos_[query_idx] + batch_idx;
-        PositionalEncodingQK(q, pos, layer_, query_scale, q);
+        PositionalEncodingQK(q, pos, layer_, query_scale);
         const size_t start_pos = StartPos(pos, layer_);
         size_t last_pos = pos;
 
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 3da48e1..9919abf 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -333,9 +333,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
 
 // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
-    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos,
-    float* HWY_RESTRICT x_out) {
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, int pos) {
   PROFILER_FUNC;
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
@@ -369,8 +368,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
                                      hn::Mul(x1_vec, cos_theta_vec));
 
     // Store
-    hn::StoreU(xout_0_vec, d, x_out + dim);
-    hn::StoreU(xout_1_vec, d, x_out + dim + half_dim_qkv);
+    hn::StoreU(xout_0_vec, d, x + dim);
+    hn::StoreU(xout_1_vec, d, x + dim + half_dim_qkv);
   }
 
   // Vectorize computation for remaining dims - same as above, but with LoadN.
@@ -399,8 +398,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
         hn::MulAdd(x0_vec, sin_theta_vec, hn::Mul(x1_vec, cos_theta_vec));
 
     // Store
-    hn::StoreN(xout_0_vec, d, x_out + dim, remaining_dims);
-    hn::StoreN(xout_1_vec, d, x_out + dim + half_dim_qkv, remaining_dims);
+    hn::StoreN(xout_0_vec, d, x + dim, remaining_dims);
+    hn::StoreN(xout_1_vec, d, x + dim + half_dim_qkv, remaining_dims);
   }
 }
 
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index ddf5ec6..8e57373 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -370,9 +370,8 @@ void TestSigmoid() {
 }
 
 static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
-    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
-    const float* HWY_RESTRICT inv_timescale, int pos,
-    float* HWY_RESTRICT x_out) {
+    const float mul, float* HWY_RESTRICT x, size_t dim_qkv,
+    const float* HWY_RESTRICT inv_timescale, int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -381,8 +380,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
     const float sin_val = sinf(theta);
     const float x0 = x[dim];
     const float x1 = x[dim + half_dim_qkv];
-    x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
-    x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
+    x[dim] = mul * (x0 * cos_val - x1 * sin_val);
+    x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
   }
 }
 
@@ -413,10 +412,11 @@ void TestRopeAndMulBy() {
   // Assert VectorizedRope computation is same as regular rope at different pos.
   for (int pos = 1; pos < 500; pos++) {
     // Rope'd Q embeddings
-    ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                       qexpected.data());
-    RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                 qactual.data());
+    hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv * sizeof(float));
+    hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv * sizeof(float));
+    ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(),
+                       pos);
+    RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos);
 
     for (int i = 0; i < dim_qkv; ++i) {
       EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
@@ -424,10 +424,11 @@ void TestRopeAndMulBy() {
     }
 
     // Rope'd K embeddings
-    ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                       kexpected.data());
-    RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
-                 kactual.data());
+    hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv * sizeof(float));
+    hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv * sizeof(float));
+    ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(),
+                       pos);
+    RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos);
 
     for (int i = 0; i < dim_qkv; ++i) {
       EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
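
Note: the core change above is that RopeAndMulBy / ScalarRopeAndMulBy now rotate their input in place instead of writing to a separate x_out buffer, so callers copy data first only where they still need the original (the MHA KV-cache path and the test). The standalone sketch below illustrates that in-place contract; the helper name RopeAndScaleInPlace and the 10000-base timescale are illustrative assumptions, not code from this patch (gemma.cpp precomputes inv_timescale in Activations::Allocate).

// Minimal sketch of the in-place RoPE-and-scale contract, mirroring
// ScalarRopeAndMulBy in the patch. Assumption: standard RoPE timescales
// with base 10000; names here are hypothetical.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Rotates each (x[d], x[d + dim_qkv/2]) pair in place and scales by mul;
// no separate output buffer is required.
void RopeAndScaleInPlace(float mul, float* x, size_t dim_qkv,
                         const float* inv_timescale, int pos) {
  const size_t half = dim_qkv / 2;
  for (size_t d = 0; d < half; ++d) {
    const float theta = static_cast<float>(pos) * inv_timescale[d];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = x[d];
    const float x1 = x[d + half];
    x[d] = mul * (x0 * c - x1 * s);
    x[d + half] = mul * (x0 * s + x1 * c);
  }
}

int main() {
  const size_t dim_qkv = 8;
  std::vector<float> inv_timescale(dim_qkv / 2);
  for (size_t d = 0; d < dim_qkv / 2; ++d) {
    // Assumed conventional RoPE base of 10000.
    inv_timescale[d] = std::pow(10000.0f, -2.0f * d / dim_qkv);
  }
  std::vector<float> q(dim_qkv, 1.0f);
  RopeAndScaleInPlace(/*mul=*/0.5f, q.data(), dim_qkv, inv_timescale.data(),
                      /*pos=*/3);
  for (float v : q) printf("%.4f\n", v);
  return 0;
}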