From e4611f3b33eff3402810c79741d3ea44ec41a107 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 4 Feb 2026 17:30:49 +0100 Subject: [PATCH] Fix rope_norm --- ggml/src/ggml-cuda/rope.cu | 59 +++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 0972146270..f16cb1df11 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -45,8 +45,13 @@ static __global__ void rope_norm(const T * x, D * dst, const int ne0, const int ne1, + const int ne2, const int nb01, const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, const int n_dims, const int32_t * pos, const float freq_scale, @@ -65,17 +70,17 @@ static __global__ void rope_norm(const T * x, const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*nb02 + row_x*nb01 + i0; + const uint32_t i3 = row_dst / (ne1*ne2); + const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1; + const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1; + int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13; + const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * nb11 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -327,8 +332,13 @@ static void rope_norm_cuda(const T * x, D * dst, const int ne0, const int ne1, + const int ne2, const int nb01, const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, const int n_dims, const int nr, const int32_t * pos, @@ -343,19 +353,19 @@ static void rope_norm_cuda(const T * x, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } @@ -622,17 +632,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); }