From ba2fd11cdf2ac70093ec4287e34dcfffa004f171 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 6 Mar 2026 08:32:40 -0800 Subject: [PATCH] cpu: skip redundant ROPE cache updates (#20149) --- ggml/src/ggml-cpu/ops.cpp | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 243f01caf8..2c372f9635 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5803,28 +5803,33 @@ static void ggml_compute_forward_rope_flt( const int32_t * pos = (const int32_t *) src1->data; + int64_t last_i2 = -1; + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!mrope_used) { - const int64_t p = pos[i2]; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - else { - const int64_t p_t = pos[i2]; - const int64_t p_h = pos[i2 + ne2]; - const int64_t p_w = pos[i2 + ne2 * 2]; - const int64_t p_e = pos[i2 + ne2 * 3]; - ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, - freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads - if (ir++ < ir0) continue; + if (ir++ < ir0) continue; // skip rows mapped to other threads if (ir > ir1) break; + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (last_i2 != i2) { + if (!mrope_used) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, 
is_imrope, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + last_i2 = i2; + } + T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);