diff --git a/gemma/ops.h b/gemma/ops.h index 8cac3f9..ed12ef4 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -685,6 +685,28 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( } } +/* RoPE as in Rotary Position Embeddings from the RoFormer paper + (https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated + as a function of their absolute position using the rotation matrix R before + the self-attention operation. R is a d x d matrix. + + R = cos(m*theta_1) -sin(m*theta_1) ... 0 0 + sin(m*theta_1) cos(m*theta_1) ... 0 0 + 0 0 ... 0 0 + 0 0 ... 0 0 + ... + 0 0 ... cos(m*theta_{d/2}) -sin(m*theta_{d/2}) + 0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2}) + + Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector, + i = 1, ..., d/2 indexes the rotated pairs, and m is the absolute position. + + Applying the rotation matrix R to a vector v is equivalent to rotating every + consecutive pair of dimensions of v, i.e. v_{2i-1} and v_{2i}, by an angle + m*theta_i. However, in the Gemma implementation we choose to rotate + the pairs of dimensions v_{i} and v_{i + d//2} instead. +*/ + static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, size_t dim_qkv, size_t pos) { HWY_DASSERT(dim_qkv % 2 == 0);