diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cb7fa2c9cb..af57685a37 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; - uint32_t ncols; uint32_t nrows; uint32_t n_dims; float freq_scale; - uint32_t p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint32_t has_ff; - uint32_t ne02; - uint32_t s1; - uint32_t s2; int32_t sections[4]; uint32_t is_imrope; uint32_t is_back; uint32_t set_rows_stride; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; }; +static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); // For fused rms_norm+mul+rope(+view+set_rows) struct vk_op_rms_norm_mul_rope_push_constants { @@ -10405,12 +10410,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type); + + uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type); + uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type); + uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - has_ff, (uint32_t)src0->ne[2], nb01, nb02, + (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + nb01, nb02, nb03, + nb11, nb12, nb13, }; return rope; @@ -14798,6 +14813,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: + return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 9d6d366542..55b89f19a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -112,12 +112,11 @@ void rms_norm(uint num_iters) { #if RMS_NORM_ROPE_FUSION barrier(); rope_params rp = p.rope; - uint rope_row = (samp*nchannels + channel)*nrows + row; for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { - rope_neox(t, rope_row, rp); + rope_neox(t, row, channel, samp, rp); } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { - rope_norm(t, rope_row, rp); + rope_norm(t, row, channel, samp, rp); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index aacec98469..2e53459909 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) { return 1.0f - min(1.0f, max(0.0f, y)); } -uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) { #if RMS_NORM_ROPE_FUSION // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out sin_theta = sin(theta) * mscale; } -void rope_norm(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0; - const uint ix = rope_a_coord(i0, i01, i02, p); + uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_neox(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { } -void rope_multi(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (p.is_imrope != 0) { if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } else { if (sector < p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } @@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_vision(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - const uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; const int sec_w = p.sections[1] + p.sections[0]; @@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (sector < p.sections[0]) { const uint p0 = sector; - theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0); } else if (sector >= p.sections[0] && sector < sec_w) { const uint p0 = sector - p.sections[0]; - theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0); } const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index f7587468a8..1528fbeeae 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_multi(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_multi(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index acb8ed7815..ad0896095d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_neox(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_neox(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 0033cdb224..11220817df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_norm(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_norm(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 939cf3c51c..ec6ceaca9b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -5,24 +5,29 @@ struct rope_params { uint rope_mode; - uint ncols; uint nrows; uint n_dims; float freq_scale; - uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint has_ff; - uint ne02; - uint nb01; - uint nb02; int sections[4]; uint is_imrope; uint is_back; uint set_rows_stride; + + uint ne00; + uint ne01; + uint ne02; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d93800b5e7..ca71efb2f5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_vision(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_vision(i0, i1, i2, i3, pc); }