vulkan: fix SSM_CONV PP scaling with large ubatch sizes (#20379)

* vulkan: optimize SSM_CONV workgroup dispatch for large ubatch

Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common
d_conv size). Fixes PP performance degradation with ubatch > 512.

Ref: ggml-org/llama.cpp#18725

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: remove unused shared memory declaration in SSM_CONV

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
ProgenyAlpha 2026-03-12 05:03:18 -04:00 committed by GitHub
parent de190154c8
commit 40c550d4f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 11 deletions

View File

@ -4576,7 +4576,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

View File

@ -5,8 +5,9 @@
#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
};
void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
const uint i3 = gl_WorkGroupID.z;
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
return;
}
const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];
if (nc == 4) {
sum = dot(
vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
);
} else {
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
sum += src0[src0_base + i0] * src1[src1_base + i0];
}
}
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
dst[dst_idx] = sum;
}