diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 8d8c189444..576992da91 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -174,12 +174,29 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta); } -template -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_multi(const T * x, + T * dst, + const int ne0, + const int ne1, + const int ne2, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; @@ -187,11 +204,12 @@ static __global__ void rope_multi( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne1*ne2); + const uint32_t i2 = (row_dst - i3 * ne1 * ne2) / ne1; + const uint32_t i1 = row_dst - i3 * ne1 * ne2 - i2 * ne1; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*nb02 + row_x*nb01 + i0/2; + int idst = i0 / 2 + i1 * nb11 + i2 * nb12 + i3 * nb13; + const int ix = i0 / 2 + + i1 * nb01 + i2 * nb02 + i3 * nb03; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -206,27 +224,24 @@ static __global__ void rope_multi( float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne2 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne2 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne2 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne2 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne2 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -370,26 +385,45 @@ static void rope_neox_cuda(const T * x, } } -template -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int nb01, const int nb02, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { +template +static void rope_multi_cuda(const T * x, + T * dst, + const int ne0, + const int ne1, + const int ne2, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne0 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, nb01, nb02, n_dims, pos, freq_scale, ext_factor, - attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); + x, dst, ne0, ne1, ne2, nb01, nb02, nb03, nb11, nb12, nb13, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, is_imrope); } } @@ -530,13 +564,13 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11, + nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, nb01, nb02, nb03, nb11, + nb12, nb13, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); }