diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ce3c85e758..2a2f7f4f11 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4576,7 +4576,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index d62696bcfa..6802b1fc95 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -5,8 +5,9 @@ #include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint TOKENS_PER_WG = 16; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; @@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants { }; void main() { - const uint global_thread_id = gl_GlobalInvocationID.x; - const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y; const uint i3 = gl_WorkGroupID.z; - if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + if (i1 >= nr || i2 >= n_t || i3 >= n_s) { return; } - const uint i1 = global_thread_id; const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); const uint src1_base = i1 * (nb11 / 4); - const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; float sum = 0.0; - [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { - const uint src0_idx = src0_base + i0; - const uint src1_idx = src1_base + i0; - sum += src0[src0_idx] * src1[src1_idx]; + + if (nc == 4) { + sum = dot( + vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]), + vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3]) + ); + } else { + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; }