Fix rope_vision

This commit is contained in:
Oliver Simons 2026-02-04 17:15:24 +01:00
parent 5f08773a4d
commit 99b7b155a8
1 changed files with 64 additions and 30 deletions

View File

@ -259,11 +259,27 @@ static __global__ void rope_multi(const T * x,
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
@ -272,24 +288,24 @@ static __global__ void rope_vision(
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const uint32_t i3 = row_dst / (ne1 * ne2);
const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1;
const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*nb02 + row_x*nb01 + i0/2;
int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 / 2 + +i1 * nb01 + i2 * nb02 + i3 * nb03;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2] * powf(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
theta_base = pos[i2 + ne2] * powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -427,11 +443,29 @@ static void rope_multi_cuda(const T * x,
}
}
template<bool forward, typename T>
static void rope_vision_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
template <bool forward, typename T>
static void rope_vision_cuda(const T * x,
T * dst,
const int ne0,
const int ne1,
const int ne2,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const mrope_sections sections,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@ -443,12 +477,12 @@ static void rope_vision_cuda(
if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections);
}
}
@ -576,13 +610,13 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}