vulkan: fix SSM_CONV PP scaling with large ubatch sizes (#20379)
* vulkan: optimize SSM_CONV workgroup dispatch for large ubatch

  Tile tokens into 2D workgroups (32x16) to reduce workgroup launch
  overhead at large ubatch sizes. Add a vec4 fast path for nc=4 (the
  common d_conv size). Fixes PP performance degradation with ubatch > 512.

  Ref: ggml-org/llama.cpp#18725

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* vulkan: remove unused shared memory declaration in SSM_CONV

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Progeny Alpha <ProgenyAlpha@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
de190154c8
commit
40c550d4f6
|
|
@ -4576,7 +4576,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,9 @@
|
|||
#include "types.glsl"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
|
||||
|
||||
layout(binding = 0) readonly buffer Src0 { float src0[]; };
|
||||
layout(binding = 1) readonly buffer Src1 { float src1[]; };
|
||||
|
|
@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
|
|||
};
|
||||
|
||||
void main() {
|
||||
const uint global_thread_id = gl_GlobalInvocationID.x;
|
||||
const uint i2 = gl_WorkGroupID.y;
|
||||
const uint i1 = gl_GlobalInvocationID.x;
|
||||
const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
|
||||
const uint i3 = gl_WorkGroupID.z;
|
||||
|
||||
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
|
||||
if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i1 = global_thread_id;
|
||||
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
|
||||
const uint src1_base = i1 * (nb11 / 4);
|
||||
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
|
||||
|
||||
float sum = 0.0;
|
||||
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
|
||||
const uint src0_idx = src0_base + i0;
|
||||
const uint src1_idx = src1_base + i0;
|
||||
sum += src0[src0_idx] * src1[src1_idx];
|
||||
|
||||
if (nc == 4) {
|
||||
sum = dot(
|
||||
vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
|
||||
vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
|
||||
);
|
||||
} else {
|
||||
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
|
||||
sum += src0[src0_base + i0] * src1[src1_base + i0];
|
||||
}
|
||||
}
|
||||
|
||||
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue