Fix rope_norm

This commit is contained in:
Oliver Simons 2026-02-04 17:30:49 +01:00
parent 99b7b155a8
commit e4611f3b33
1 changed files with 36 additions and 23 deletions

View File

@ -45,8 +45,13 @@ static __global__ void rope_norm(const T * x,
D * dst, D * dst,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int32_t * pos, const int32_t * pos,
const float freq_scale, const float freq_scale,
@ -65,17 +70,17 @@ static __global__ void rope_norm(const T * x,
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1; const uint32_t i3 = row_dst / (ne1*ne2);
const int channel_x = row_dst / ne1; const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*nb02 + row_x*nb01 + i0;
int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
// Fusion optimization: ROPE + VIEW + SET_ROWS. // Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) { if (set_rows_stride != 0) {
idst = row_x * ne0 + i0; idst = i1 * nb11 + i0;
idst += row_indices[channel_x] * set_rows_stride; idst += row_indices[i2] * set_rows_stride;
} }
const auto & store_coaelsced = [&](float x0, float x1) { const auto & store_coaelsced = [&](float x0, float x1) {
@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
return; return;
} }
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -327,8 +332,13 @@ static void rope_norm_cuda(const T * x,
D * dst, D * dst,
const int ne0, const int ne0,
const int ne1, const int ne1,
const int ne2,
const int nb01, const int nb01,
const int nb02, const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int nr, const int nr,
const int32_t * pos, const int32_t * pos,
@ -343,19 +353,19 @@ static void rope_norm_cuda(const T * x,
cudaStream_t stream) { cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1); const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
freq_factors, row_indices, set_rows_stride); corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else { } else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
freq_factors, row_indices, set_rows_stride); corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} }
} }
@ -622,17 +632,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
} }
} else { } else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims, rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr, rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }