vulkan: optimize ssm_scan (#18630)

* vulkan: optimize ssm_scan

* fix warp vs subgroup naming
This commit is contained in:
Jeff Bolz 2026-01-08 08:16:54 -06:00 committed by GitHub
parent 55abc39355
commit cb14b06995
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 59 additions and 69 deletions

View File

@ -570,6 +570,7 @@ struct vk_device_struct {
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
bool subgroup_basic;
bool subgroup_arithmetic;
bool subgroup_shuffle;
bool subgroup_ballot;
@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
#ifdef __APPLE__
@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t d_state = src0->ne[0];
uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
const uint32_t SPLIT_H = 16;
size_t shmem_size = d_state * sizeof(float);
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
if (!device->subgroup_basic) {
return false;
}

View File

@ -1,6 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
@ -9,7 +10,8 @@
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;
const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@ -41,22 +43,28 @@ float softplus(float x) {
}
}
shared float stateC[SPLIT_H * D_STATE];
#if !USE_SUBGROUP_ADD
shared float temp[D_STATE];
#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint subgroup = gl_SubgroupID;
const uint lane = gl_SubgroupInvocationID;
const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
const uint head_idx = subgroup_idx / d_head;
const uint head_off = (subgroup_idx % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
@ -65,76 +73,52 @@ void main() {
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
float state[SPLIT_H];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
state[j] = s0[s0_base_idx + j * D_STATE + tid];
float state[c_factor];
[[unroll]] for (uint j = 0; j < c_factor; j++) {
state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
}
float a = A[A_base_idx];
for (uint i = 0; i < n_tok; i++) {
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
const float dA = exp(dt_soft_plus * A[A_base_idx]);
const float B_val = B[B_base_idx + i * stride_B + tid];
const float C_val = C[C_base_idx + i * stride_C + tid];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
float state_sum = 0.0f;
const float dA = exp(dt_soft_plus * a);
const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
[[unroll]] for (uint j = 0; j < c_factor; j++) {
float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
stateC[j * D_STATE + tid] = state[j] * C_val;
state_sum += state[j] * C_val;
}
#if USE_SUBGROUP_ADD
state_sum = subgroupAdd(state_sum);
#else
temp[tid] = state_sum;
barrier();
[[unroll]]
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + w];
}
[[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
if (lane < s) {
temp[tid] += temp[tid + s];
}
barrier();
}
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
const uint max_idx = SUBGROUP_SIZE - 1 +
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
if (idx < SPLIT_H * D_STATE ||
max_idx < SPLIT_H * D_STATE) {
float sc;
#if USE_SUBGROUP_ADD
sc = stateC[idx];
sc = subgroupAdd(sc);
#else
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
if (tid % SUBGROUP_SIZE == 0) {
sc = stateC[idx];
}
// get the value from lane 0
state_sum = temp[subgroup * SUBGROUP_SIZE];
barrier();
#endif
if (tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = sc;
}
}
if (lane == 0) {
d[y_base_idx + i * stride_y] = state_sum;
}
barrier();
}
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
d[s_base_idx + j * D_STATE + tid] = state[j];
// write back the state
[[unroll]]
for (int j = 0; j < c_factor; j++) {
d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
}
}