From bfd1f453cb936a83f1a64168746029f2aa509fbb Mon Sep 17 00:00:00 2001
From: Rithik Sharma
Date: Fri, 10 Apr 2026 10:52:38 -0700
Subject: [PATCH] ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669)

---
 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 13 +++++--
 .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl  | 34 +++++++++----------
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 3b894a9b9c..e979783f02 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));

 #ifndef __EMSCRIPTEN__
-    // Only support square f16 matrices of size 8 or 16 for now
+    // Accept f16 subgroup matrix configurations (square or non-square).
+    // NVIDIA GPUs typically report square configs (e.g. 16x16x16),
+    // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
+    // The shaders are already parameterized to handle any M/N/K dimensions.
     bool valid_subgroup_matrix_config = false;
     if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
         for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
             const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
-            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
-                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+            if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
                 config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
                 ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
                 ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
@@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
                 break;
             }
+            // Head dimensions must be divisible by subgroup matrix dimensions
+            if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
+                src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
+                break;
+            }
             // Head dimensions must fit in workgroup memory with minimum tile sizes
             size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
             const bool has_mask = op->src[3] != nullptr;
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
index 8b76cecba9..aa2d2e54db 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
@@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
 #endif
     for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
         let inter_offset = kv_block * SG_MAT_N;
-        var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE);
+        var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE);

-        var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK);
+        var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-        var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1);
+        var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1);
 #else
-        var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
+        var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
 #endif

         var t: u32 = 1u;
         for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
             let h0 = t * SG_MAT_K;
-            var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK);
+            var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1);
+            var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1);
 #else
-            var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+            var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = q0;
             k_cur = k0;

             let h1 = (t + 1u) * SG_MAT_K;
-            var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK);
+            var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1);
+            var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1);
 #else
-            var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+            var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = q1g;
@@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
         // handle odd tail
         if (t < HEAD_DIM_QK / SG_MAT_K) {
             let h = t * SG_MAT_K;
-            var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK);
+            var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1);
+            var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1);
 #else
-            var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+            var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = qn;
@@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
          head_dim_block < HEAD_DIM_V;
          head_dim_block += num_subgroups * SG_MAT_N) {
         // load O submatrix from shared memory
-        var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>(
+        var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>(
             &o_shmem,
             head_dim_block,
             false,
@@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
         );
         for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
             let p_offset = kv_block * SG_MAT_N;
-            var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>(
+            var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>(
                 &inter_shmem,
                 p_offset,
                 false,
@@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
 #ifdef KV_DIRECT
             let v_block_row = kv_tile + kv_block * SG_MAT_N;
             let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
-            var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
+            var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
                 &V,
                 v_global_offset,
                 false,
@@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
             );
 #else
             let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
-            var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
+            var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
                 &kv_shmem,
                 v_block_offset + head_dim_block,
                 false,
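
As a sanity check of the new logic, here is a minimal standalone sketch (not part of the patch, and not the real ggml-webgpu or wgpu API) of the two host-side decisions the ggml-webgpu.cpp hunks make: the capability scan now accepts any f16 subgroup matrix configuration, square or not, and GGML_OP_FLASH_ATTN_EXT is only reported as supported when the Q/K head dimension divides evenly by sg_mat_k and the V head dimension by sg_mat_n. The struct and helper names below are illustrative stand-ins; only the modulo conditions mirror the patch.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for wgpu::SubgroupMatrixConfig; field names are not the real API.
struct sg_mat_config {
    uint32_t m, n, k;    // reported subgroup matrix dimensions (need not be square)
    bool     f16_inout;  // componentType and resultComponentType are both F16
};

// Mirrors the relaxed capability check: take the first f16 config, square or not.
static bool pick_config(const sg_mat_config * cfgs, size_t count, sg_mat_config & out) {
    for (size_t i = 0; i < count; i++) {
        if (cfgs[i].f16_inout) {
            out = cfgs[i];
            return true;
        }
    }
    return false;
}

// Mirrors the new supports_op guard: Q/K head dim must divide by sg_mat_k,
// V head dim must divide by sg_mat_n, otherwise the op is not offloaded.
static bool flash_attn_fits(const sg_mat_config & c, uint32_t head_dim_qk, uint32_t head_dim_v) {
    return head_dim_qk % c.k == 0 && head_dim_v % c.n == 0;
}

int main() {
    // Intel Xe2-style non-square config from the commit message: M=8, N=16, K=16.
    const sg_mat_config configs[] = { { 8, 16, 16, true } };
    sg_mat_config cfg = {};
    if (!pick_config(configs, 1, cfg)) {
        return 1;
    }
    std::printf("head dim  64: %s\n", flash_attn_fits(cfg,  64,  64) ? "subgroup matrix path" : "fallback"); //  64 % 16 == 0
    std::printf("head dim 112: %s\n", flash_attn_fits(cfg, 112, 112) ? "subgroup matrix path" : "fallback"); // 112 % 16 == 0
    std::printf("head dim  72: %s\n", flash_attn_fits(cfg,  72,  72) ? "subgroup matrix path" : "fallback"); //  72 % 16 != 0
    return 0;
}

Because the guard sits in ggml_backend_webgpu_device_supports_op, head dimensions that do not tile evenly simply fall back to another backend instead of reaching a shader that assumes divisibility.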
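
The flash_attn.wgsl hunks touch a loop that keeps the current Q/K tile in registers while prefetching the next one, unrolls by two tiles, and finishes with an odd-count tail. The plain C++ sketch below reproduces only that control flow, with scalar dot products standing in for subgroupMatrixLoad plus subgroupMatrixMultiplyAccumulate; the tile sizes and the final drain step are assumptions for illustration, not code taken from the shader.

#include <cstdio>
#include <vector>

int main() {
    // Example sizes (assumptions, not taken from the shader): 4 K-tiles exercises
    // both the unrolled-by-two body and the tail branch.
    const int HEAD_DIM_QK = 64;
    const int SG_MAT_K    = 16;
    const int TILES       = HEAD_DIM_QK / SG_MAT_K;

    std::vector<float> q(HEAD_DIM_QK, 1.0f), k(HEAD_DIM_QK, 2.0f);

    // One "tile" of work: a scalar dot product over SG_MAT_K elements, standing in
    // for a subgroup matrix load plus multiply-accumulate.
    auto tile_dot = [&](int t) {
        float s = 0.0f;
        for (int i = 0; i < SG_MAT_K; i++) {
            s += q[t * SG_MAT_K + i] * k[t * SG_MAT_K + i];
        }
        return s;
    };

    float acc = 0.0f;
    float cur = tile_dot(0);           // tile 0 is loaded before the loop
    int   t   = 1;
    for (; t + 1 < TILES; t += 2) {    // unrolled by two: prefetch, accumulate, rotate
        float next0 = tile_dot(t);
        acc += cur;
        cur  = next0;
        float next1 = tile_dot(t + 1);
        acc += cur;
        cur  = next1;
    }
    if (t < TILES) {                   // odd tail: one prefetched tile remains
        float nextn = tile_dot(t);
        acc += cur;
        cur  = nextn;
    }
    acc += cur;                        // drain the last in-flight tile (assumed final step)
    std::printf("acc = %.1f (expected %.1f)\n", acc, 2.0f * HEAD_DIM_QK);
    return 0;
}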