Fix rope_norm

This commit is contained in:
Oliver Simons 2026-02-04 17:30:49 +01:00
parent 99b7b155a8
commit e4611f3b33
1 changed files with 36 additions and 23 deletions

View File

@ -45,8 +45,13 @@ static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int32_t * pos,
const float freq_scale,
@ -65,17 +70,17 @@ static __global__ void rope_norm(const T * x,
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*nb02 + row_x*nb01 + i0;
const uint32_t i3 = row_dst / (ne1*ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0;
idst += row_indices[channel_x] * set_rows_stride;
idst = i1 * nb11 + i0;
idst += row_indices[i2] * set_rows_stride;
}
const auto & store_coaelsced = [&](float x0, float x1) {
@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -327,8 +332,13 @@ static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int nr,
const int32_t * pos,
@ -343,19 +353,19 @@ static void rope_norm_cuda(const T * x,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, nb01, nb02, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
@ -622,17 +632,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
}
} else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, nb01, nb02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, nb01, nb02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}