Documenting the RoPE implementation.

PiperOrigin-RevId: 636175297
This commit is contained in:
Apoorv Reddy 2024-05-22 08:25:31 -07:00 committed by Copybara-Service
parent c0643577c3
commit 1aaf3b3aae
1 changed files with 22 additions and 0 deletions

View File

@ -685,6 +685,28 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
}
}
/* RoPE as in Rotary Position Embeddings from the RoFormer paper
(https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
as a function of their absolute position using the rotation matrix R before
the self-attention operation. R is a d x d matrix.
R = cos(m*theta_1) -sin(m*theta_1) ... 0 0
sin(m*theta_1) cos(m*theta_1)
0 0 ... 0 0
0 0 ... 0 0
...
0 0 ... cos(m*theta_{d/2}) sin(m*theta_{d/2})
0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2})
Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector and
i is the ith index of the vector.
Applying the rotation matrix R to a vector v is equivalent to rotating every
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
m*theta_i. However in the Gemma implementation we choose to rotate
the pairs of dimensions v_{i} and v_{i + d//2} instead.
*/
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
size_t dim_qkv, size_t pos) {
HWY_DASSERT(dim_qkv % 2 == 0);