Rename variables + fix rope_neox

The memory layout seems to be shared with Vulkan, so we can port the fix from
https://github.com/ggml-org/llama.cpp/pull/19299
Oliver Simons 2026-02-04 16:47:04 +01:00
parent 1946e46f4c
commit f7c330aa7e
1 changed file with 71 additions and 52 deletions

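For context, here is a minimal sketch of the new rope_neox indexing that the diff below introduces. The helper name rope_neox_src_index and the __host__ __device__ wrapper are illustrative assumptions, not code from this commit: instead of the old flat index channel_x*s2 + row_x*s1 + i0/2 built from a single stride pair, the destination row is decomposed into (i1, i2, i3) and combined with per-dimension element strides nb01/nb02/nb03 for the source (and nb11/nb12/nb13 for the destination), presumably so the kernel stays correct when the strides are not the contiguous defaults.

// Illustrative sketch only, not part of the commit.
__host__ __device__ static int rope_neox_src_index(int row_dst, int i0,
                                                   int ne1, int ne2,
                                                   int nb01, int nb02, int nb03) {
    const int i3 = row_dst / (ne1 * ne2);               // index along dimension 3
    const int i2 = (row_dst - i3 * ne1 * ne2) / ne1;    // index along dimension 2
    const int i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1; // index along dimension 1
    return i0 / 2 + i1 * nb01 + i2 * nb02 + i3 * nb03;  // element offset into x
}

For a contiguous source (nb01 == ne0, nb02 == ne0*ne1, nb03 == ne0*ne1*ne2) this reduces to the old channel_x*s2 + row_x*s1 + i0/2 expression, since channel_x == i2 + i3*ne2 and row_x == i1.
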
@@ -45,8 +45,8 @@ static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int nb01,
const int nb02,
const int n_dims,
const int32_t * pos,
const float freq_scale,
@@ -69,7 +69,7 @@ static __global__ void rope_norm(const T * x,
const int channel_x = row_dst / ne1;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
const int ix = channel_x*nb02 + row_x*nb01 + i0;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
@@ -112,8 +112,13 @@ static __global__ void rope_neox(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int32_t * pos,
const float freq_scale,
@@ -132,17 +137,18 @@ static __global__ void rope_neox(const T * x,
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const uint32_t i3 = row_dst / (ne1*ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
int idst = row_dst * ne0 + i0 / 2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 / 2 + i1 * nb01 + i2 * nb02 + i3 * nb03;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0 / 2;
idst += row_indices[channel_x] * set_rows_stride;
idst = i1 * nb11 + i0 / 2;
idst += row_indices[i2] * set_rows_stride;
}
if (i0 >= n_dims) {
@@ -152,7 +158,7 @@ static __global__ void rope_neox(const T * x,
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -170,7 +176,7 @@ static __global__ void rope_neox(const T * x,
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -185,7 +191,7 @@ static __global__ void rope_multi(
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const int ix = channel_x*nb02 + row_x*nb01 + i0/2;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@@ -240,7 +246,7 @@ static __global__ void rope_multi(
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -255,7 +261,7 @@ static __global__ void rope_vision(
const int channel_x = row_dst / ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
const int ix = channel_x*nb02 + row_x*nb01 + i0/2;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
@@ -290,8 +296,8 @@ static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int nb01,
const int nb02,
const int n_dims,
const int nr,
const int32_t * pos,
@@ -313,11 +319,11 @@ static void rope_norm_cuda(const T * x,
if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
}
}
@@ -327,8 +333,13 @@ static void rope_neox_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int s1,
const int s2,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int nr,
const int32_t * pos,
@@ -343,25 +354,25 @@ static void rope_neox_cuda(const T * x,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
@@ -373,18 +384,18 @@ static void rope_multi_cuda(
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
template<bool forward, typename T>
static void rope_vision_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
@@ -398,11 +409,11 @@ static void rope_vision_cuda(
if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}
@@ -443,8 +454,13 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
const size_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
const size_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -495,28 +511,31 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
@@ -524,26 +543,26 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
} else {