From f7c330aa7e911e0515cac6746eb3c5260d4f5266 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 4 Feb 2026 16:47:04 +0100 Subject: [PATCH] Rename variables + fix rope_neox Seems memory layout is shared with Vulkan so we can port fix from https://github.com/ggml-org/llama.cpp/pull/19299 --- ggml/src/ggml-cuda/rope.cu | 123 +++++++++++++++++++++---------------- 1 file changed, 71 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a..8d8c189444 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -45,8 +45,8 @@ static __global__ void rope_norm(const T * x, D * dst, const int ne0, const int ne1, - const int s1, - const int s2, + const int nb01, + const int nb02, const int n_dims, const int32_t * pos, const float freq_scale, @@ -69,7 +69,7 @@ static __global__ void rope_norm(const T * x, const int channel_x = row_dst / ne1; int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const int ix = channel_x*nb02 + row_x*nb01 + i0; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. @@ -112,8 +112,13 @@ static __global__ void rope_neox(const T * x, D * dst, const int ne0, const int ne1, - const int s1, - const int s2, + const int ne2, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, const int n_dims, const int32_t * pos, const float freq_scale, @@ -132,17 +137,18 @@ static __global__ void rope_neox(const T * x, const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne1*ne2); + const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1; + const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13; + const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * nb11 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +158,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -170,7 +176,7 @@ static __global__ void rope_neox(const T * x, template static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -185,7 +191,7 @@ static __global__ void rope_multi( const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + const int ix = channel_x*nb02 + row_x*nb01 + i0/2; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -240,7 +246,7 @@ static __global__ void rope_multi( template static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -255,7 +261,7 @@ static __global__ void rope_vision( const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + const int ix = channel_x*nb02 + row_x*nb01 + i0/2; const int sect_dims = sections.v[0] + sections.v[1]; const int sec_w = sections.v[1] + sections.v[0]; @@ -290,8 +296,8 @@ static void rope_norm_cuda(const T * x, D * dst, const int ne0, const int ne1, - const int s1, - const int s2, + const int nb01, + const int nb02, const int n_dims, const int nr, const int32_t * pos, @@ -313,11 +319,11 @@ static void rope_norm_cuda(const T * x, if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, + x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, + x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } @@ -327,8 +333,13 @@ static void rope_neox_cuda(const T * x, D * dst, const int ne0, const int ne1, - const int s1, - const int s2, + const int ne2, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, const int n_dims, const int nr, const int32_t * pos, @@ -343,25 +354,25 @@ static void rope_neox_cuda(const T * x, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr, const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); @@ -373,18 +384,18 @@ static void rope_multi_cuda( if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } template static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr, const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); @@ -398,11 +409,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -443,8 +454,13 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const int64_t ne02 = src0->ne[2]; // num heads const int64_t nr = ggml_nrows(src0); - const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); - const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t nb01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t nb03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t nb11 = dst->nb[1] / ggml_type_size(dst->type); + const size_t nb12 = dst->nb[2] / ggml_type_size(dst->type); + const size_t nb13 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,28 +511,31 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, + nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); @@ -524,26 +543,26 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { rope_vision_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { rope_vision_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, row_indices, set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, row_indices, set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, row_indices, set_rows_stride, stream); } else {