This commit is contained in:
Oliver Simons 2026-02-06 07:56:36 -06:00 committed by GitHub
commit f8dccbd1f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 242 additions and 142 deletions

View File

@ -43,10 +43,15 @@ static __device__ void rope_yarn(
template <bool forward, bool has_ff, typename T, typename D> template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x, static __global__ void rope_norm(const T * x,
D * dst, D * dst,
const int ne0, const int ne00,
const int ne1, const int ne01,
const int s1, const int ne02,
const int s2, const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int32_t * pos, const int32_t * pos,
const float freq_scale, const float freq_scale,
@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x,
const int set_rows_stride) { const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne00) {
return; return;
} }
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1; const uint32_t i3 = row_dst / (ne01 * ne02);
const int channel_x = row_dst / ne1; const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
int idst = i0 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = i0 + i1 * nb01 + i2 * nb02 + i3 * nb03;
// Fusion optimization: ROPE + VIEW + SET_ROWS. // Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) { if (set_rows_stride != 0) {
idst = row_x * ne0 + i0; idst = i1 * nb11 + i0;
idst += row_indices[channel_x] * set_rows_stride; idst += row_indices[i2] * set_rows_stride;
} }
const auto & store_coaelsced = [&](float x0, float x1) { const auto & store_coaelsced = [&](float x0, float x1) {
@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
return; return;
} }
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x,
template <bool forward, bool has_ff, typename T, typename D> template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x, static __global__ void rope_neox(const T * x,
D * dst, D * dst,
const int ne0, const int ne00,
const int ne1, const int ne01,
const int s1, const int ne02,
const int s2, const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int32_t * pos, const int32_t * pos,
const float freq_scale, const float freq_scale,
@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x,
const int set_rows_stride) { const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne00) {
return; return;
} }
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1; const uint32_t i3 = row_dst / (ne01 * ne02);
const int channel_x = row_dst / ne1; const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = row_dst * ne0 + i0 / 2; int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = channel_x*s2 + row_x*s1 + i0/2; const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
// Fusion optimization: ROPE + VIEW + SET_ROWS. // Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) { if (set_rows_stride != 0) {
idst = row_x * ne0 + i0 / 2; idst = i1 * nb11 + i0 / 2;
idst += row_indices[channel_x] * set_rows_stride; idst += row_indices[i2] * set_rows_stride;
} }
if (i0 >= n_dims) { if (i0 >= n_dims) {
@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x,
return; return;
} }
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x,
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta); dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
} }
template<bool forward, bool has_ff, typename T> template <bool forward, bool has_ff, typename T>
static __global__ void rope_multi( static __global__ void rope_multi(const T * x,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, T * dst,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const int ne00,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { const int ne01,
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int ne02,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const mrope_sections sections,
const bool is_imrope) {
const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne00) {
return; return;
} }
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1; const uint32_t i3 = row_dst / (ne01 * ne02);
const int channel_x = row_dst / ne1; const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int idst = row_dst*ne0 + i0/2; int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = channel_x*s2 + row_x*s1 + i0/2; const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03;
if (i0 >= n_dims) { if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@ -200,27 +229,24 @@ static __global__ void rope_multi(
float theta_base = 0.0; float theta_base = 0.0;
if (is_imrope) { if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} else { } else {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
} }
} else { } else {
if (sector < sections.v[0]) { if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} } else if (sector >= sections.v[0] && sector < sec_w) {
else if (sector >= sections.v[0] && sector < sec_w) { theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
} theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
else if (sector >= sec_w && sector < sec_w + sections.v[2]) { } else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
} }
} }
@ -238,37 +264,53 @@ static __global__ void rope_multi(
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
} }
template<bool forward, bool has_ff, typename T> template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision( static __global__ void rope_vision(const T * x,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, T * dst,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, const int ne00,
const float theta_scale, const float * freq_factors, const mrope_sections sections) { const int ne01,
const int ne02,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) { if (i0 >= ne00) {
return; return;
} }
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1; const uint32_t i3 = row_dst / (ne01 * ne02);
const int channel_x = row_dst / ne1; const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int idst = row_dst*ne0 + i0/2; int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13;
const int ix = channel_x*s2 + row_x*s1 + i0/2; const int ix = i0 / 2 + +i1 * nb01 + i2 * nb02 + i3 * nb03;
const int sect_dims = sections.v[0] + sections.v[1]; const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0]; const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims; const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0; float theta_base = 0.0;
if (sector < sections.v[0]) { if (sector < sections.v[0]) {
const int p = sector; const int p = sector;
theta_base = pos[channel_x]*powf(theta_scale, p); theta_base = pos[i2] * powf(theta_scale, p);
} } else if (sector >= sections.v[0] && sector < sec_w) {
else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0]; const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2]*powf(theta_scale, p); theta_base = pos[i2 + ne02] * powf(theta_scale, p);
} }
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -288,10 +330,15 @@ static __global__ void rope_vision(
template <bool forward, typename T, typename D> template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x, static void rope_norm_cuda(const T * x,
D * dst, D * dst,
const int ne0, const int ne00,
const int ne1, const int ne01,
const int s1, const int ne02,
const int s2, const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int nr, const int nr,
const int32_t * pos, const int32_t * pos,
@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x,
const int64_t * row_indices, const int64_t * row_indices,
const int set_rows_stride, const int set_rows_stride,
cudaStream_t stream) { cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1); const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
freq_factors, row_indices, set_rows_stride); attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else { } else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>( rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
freq_factors, row_indices, set_rows_stride); attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} }
} }
template <bool forward, typename T, typename D> template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x, static void rope_neox_cuda(const T * x,
D * dst, D * dst,
const int ne0, const int ne00,
const int ne1, const int ne01,
const int s1, const int ne02,
const int s2, const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims, const int n_dims,
const int nr, const int nr,
const int32_t * pos, const int32_t * pos,
@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x,
const int64_t * row_indices, const int64_t * row_indices,
const int set_rows_stride, const int set_rows_stride,
cudaStream_t stream) { cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0); GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1); const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>( rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
freq_factors, row_indices, set_rows_stride); attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else { } else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>( rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
freq_factors, row_indices, set_rows_stride); attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} }
} }
template<bool forward, typename T> template <bool forward, typename T>
static void rope_multi_cuda( static void rope_multi_cuda(const T * x,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, T * dst,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, const int ne00,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { const int ne01,
GGML_ASSERT(ne0 % 2 == 0); const int ne02,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const mrope_sections sections,
const bool is_imrope,
cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1); const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>( rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else { } else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>( rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} }
} }
template<bool forward, typename T> template <bool forward, typename T>
static void rope_vision_cuda( static void rope_vision_cuda(const T * x,
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, T * dst,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, const int ne00,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { const int ne01,
GGML_ASSERT(ne0 % 2 == 0); const int ne02,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const mrope_sections sections,
cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1); const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@ -398,11 +487,11 @@ static void rope_vision_cuda(
if (freq_factors == nullptr) { if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>( rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections); attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else { } else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>( rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, x, dst, ne00, ne01, ne02, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections); attn_factor, corr_dims, theta_scale, freq_factors, sections);
} }
} }
@ -443,8 +532,13 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
const int64_t ne02 = src0->ne[2]; // num heads const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0); const int64_t nr = ggml_nrows(src0);
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); const size_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
const size_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
const size_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
//const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
// compute // compute
if (is_neox) { if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else if (is_mrope && !is_vision) { } else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>( rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>( rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); corr_dims, freq_factors, sections, is_imrope, stream);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else if (is_vision) { } else if (is_vision) {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda<forward>( rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda<forward>( rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11,
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); corr_dims, freq_factors, sections, stream);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} else { } else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, nb03, nb11, nb12, nb13, n_dims, nr, pos, freq_scale, freq_base,
freq_factors, row_indices, set_rows_stride, stream); ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }