diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4d3c085f67..7e17f4945d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -570,6 +570,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_basic; bool subgroup_arithmetic; bool subgroup_shuffle; bool subgroup_ballot; @@ -4301,8 +4302,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); } else { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); @@ -4638,6 +4639,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic); device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); #ifdef __APPLE__ @@ -9870,8 +9873,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, std::array elements; - const int splitH = 16; - const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t d_state = src0->ne[0]; + uint32_t num_subgroups = d_state / ctx->device->subgroup_size; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups); const uint32_t num_workgroups_y = n_seq; elements = { num_workgroups_x, num_workgroups_y, 1 }; @@ -14777,11 +14781,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - const uint32_t SPLIT_H = 16; + size_t shmem_size = d_state * sizeof(float); - size_t stateC_size = SPLIT_H * d_state * sizeof(float); + if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } - if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + if (!device->subgroup_basic) { return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp index 8f67be9799..c7416206db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable #if USE_SUBGROUP_ADD #extension GL_KHR_shader_subgroup_arithmetic : enable #endif @@ -9,7 +10,8 @@ layout(constant_id = 0) const uint D_STATE = 128; layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; -layout(constant_id = 2) const uint SPLIT_H = 16; + +const uint32_t c_factor = D_STATE / SUBGROUP_SIZE; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -41,22 +43,28 @@ float softplus(float x) { } } -shared float stateC[SPLIT_H * D_STATE]; +#if !USE_SUBGROUP_ADD +shared float temp[D_STATE]; +#endif void main() { - const uint tid = gl_LocalInvocationID.x; - const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; - const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; - const uint seq_idx = gl_WorkGroupID.y; + const uint subgroup = gl_SubgroupID; + const uint lane = gl_SubgroupInvocationID; + const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane; + const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup; + + const uint head_idx = subgroup_idx / d_head; + const uint head_off = (subgroup_idx % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; - const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4; const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; const uint A_base_idx = (head_idx * nb31) / 4; const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; - const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx; const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; const uint stride_x = nb12 / 4; @@ -65,76 +73,52 @@ void main() { const uint stride_C = nb52 / 4; const uint stride_y = n_head * d_head; - float state[SPLIT_H]; - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - state[j] = s0[s0_base_idx + j * D_STATE + tid]; + float state[c_factor]; + + [[unroll]] for (uint j = 0; j < c_factor; j++) { + state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane]; } + float a = A[A_base_idx]; + for (uint i = 0; i < n_tok; i++) { - const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); - const float dA = exp(dt_soft_plus * A[A_base_idx]); - - const float B_val = B[B_base_idx + i * stride_B + tid]; - const float C_val = C[C_base_idx + i * stride_C + tid]; - - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + float state_sum = 0.0f; + const float dA = exp(dt_soft_plus * a); + const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus; + [[unroll]] for (uint j = 0; j < c_factor; j++) { + float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane]; + float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane]; state[j] = (state[j] * dA) + (B_val * x_dt); - - stateC[j * D_STATE + tid] = state[j] * C_val; + state_sum += state[j] * C_val; } +#if USE_SUBGROUP_ADD + state_sum = subgroupAdd(state_sum); +#else + temp[tid] = state_sum; barrier(); - [[unroll]] - for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); - if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + w]; - } + [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) { + if (lane < s) { + temp[tid] += temp[tid + s]; } barrier(); } - - [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { - const uint idx = (tid % SUBGROUP_SIZE) + - D_STATE * (tid / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - const uint max_idx = SUBGROUP_SIZE - 1 + - D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - - if (idx < SPLIT_H * D_STATE || - max_idx < SPLIT_H * D_STATE) { - float sc; -#if USE_SUBGROUP_ADD - sc = stateC[idx]; - sc = subgroupAdd(sc); -#else - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; - } - barrier(); - } - if (tid % SUBGROUP_SIZE == 0) { - sc = stateC[idx]; - } + // get the value from lane 0 + state_sum = temp[subgroup * SUBGROUP_SIZE]; + barrier(); #endif - if (tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = sc; - } - } + if (lane == 0) { + d[y_base_idx + i * stride_y] = state_sum; } - - barrier(); } - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - d[s_base_idx + j * D_STATE + tid] = state[j]; + // write back the state + [[unroll]] + for (int j = 0; j < c_factor; j++) { + d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j]; } }