ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669)

Rithik Sharma, 2026-04-10 10:52:38 -07:00, committed by GitHub
parent e4fed9d08d
commit bfd1f453cb
2 changed files with 27 additions and 20 deletions


@@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
 #ifndef __EMSCRIPTEN__
-    // Only support square f16 matrices of size 8 or 16 for now
+    // Accept f16 subgroup matrix configurations (square or non-square).
+    // NVIDIA GPUs typically report square configs (e.g. 16x16x16),
+    // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
+    // The shaders are already parameterized to handle any M/N/K dimensions.
     bool valid_subgroup_matrix_config = false;
     if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
         for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
             const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
-            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
-                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+            if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
                 config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
                 ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
                 ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
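For reference, a host-side sketch of the relaxed selection rule above. This is illustrative only: pick_f16_config is a hypothetical helper, and the exact name of the chained configs struct (AdapterPropertiesSubgroupMatrixConfigs here) is an assumption about the Dawn C++ bindings rather than part of this patch.

    // Take the first reported config whose operand and result component types
    // are f16, regardless of whether M, N and K are equal.
    static bool pick_f16_config(const wgpu::AdapterPropertiesSubgroupMatrixConfigs & configs,
                                uint32_t & m, uint32_t & n, uint32_t & k) {
        for (size_t i = 0; i < configs.configCount; i++) {
            const wgpu::SubgroupMatrixConfig & c = configs.configs[i];
            if (c.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
                c.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
                m = c.M;  // e.g. 16 on NVIDIA, 8 on Intel Xe2
                n = c.N;  // e.g. 16 on both
                k = c.K;  // e.g. 16 on both
                return true;
            }
        }
        return false;
    }

Note this keeps the first matching config; an adapter may report several, and nothing in the sketch ranks them.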
@@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
                 break;
             }
+            // Head dimensions must be divisible by subgroup matrix dimensions
+            if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
+                src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
+                break;
+            }
             // Head dimensions must fit in workgroup memory with minimum tile sizes
             size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
             const bool has_mask = op->src[3] != nullptr;
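To make the new guard concrete: with Intel Xe2's 8x16x16 config (sg_mat_k = 16, sg_mat_n = 16), the common head dimension 128 passes (128 % 16 == 0), while a head dimension such as 100 does not (100 % 16 == 4), so the op is rejected rather than producing partial tiles. A standalone sketch of the same check (names hypothetical):

    #include <cstdint>

    // head_dim_qk mirrors src0->ne[0] (the Q/K head dimension) and head_dim_v
    // mirrors src2->ne[0] (the V head dimension); sg_mat_k and sg_mat_n come
    // from the selected subgroup matrix config.
    static bool head_dims_tileable(int64_t head_dim_qk, int64_t head_dim_v,
                                   uint32_t sg_mat_k, uint32_t sg_mat_n) {
        return head_dim_qk % sg_mat_k == 0 && head_dim_v % sg_mat_n == 0;
    }

    // head_dims_tileable(128, 128, 16, 16) -> true
    // head_dims_tileable(100, 100, 16, 16) -> false (op rejected by supports_op)

The second changed file, the WGSL attention shader, follows.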


@@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 #endif
     for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
         let inter_offset = kv_block * SG_MAT_N;
-        var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
-        var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
+        var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);
+        var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-        var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
+        var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
 #else
-        var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
+        var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
 #endif
         var t: u32 = 1u;
         for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
             let h0 = t * SG_MAT_K;
-            var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
+            var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
+            var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
 #else
-            var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+            var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = q0;
             k_cur = k0;
             let h1 = (t + 1u) * SG_MAT_K;
-            var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
+            var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
+            var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
 #else
-            var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+            var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = q1g;
@@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         // handle odd tail
         if (t < HEAD_DIM_QK / SG_MAT_K) {
             let h = t * SG_MAT_K;
-            var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
+            var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
 #ifdef KV_DIRECT
-            var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
+            var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
 #else
-            var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+            var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
 #endif
             acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
             q_cur = qn;
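Unchanged by this patch, but useful when reading the hunks above: the loop is a two-stage software pipeline that prefetches the next Q/K tiles while the multiply-accumulate consumes the current ones, with an explicit tail for an odd tile count. A self-contained C++ sketch of the same pattern, where load_tile and mma stand in for subgroupMatrixLoad and subgroupMatrixMultiplyAccumulate (all names hypothetical):

    #include <cstdint>
    #include <vector>

    using Tile = float;  // stand-in for one subgroup matrix tile
    static Tile load_tile(const std::vector<float> & src, uint32_t t) { return src[t]; }
    static Tile mma(Tile a, Tile b, Tile acc) { return acc + a * b; }

    static Tile pipelined_dot(const std::vector<float> & q, const std::vector<float> & k) {
        const uint32_t num_tiles = (uint32_t) q.size();
        Tile acc = 0.0f;
        Tile q_cur = load_tile(q, 0), k_cur = load_tile(k, 0);
        uint32_t t = 1;
        for (; t + 1 < num_tiles; t += 2) {
            Tile q0 = load_tile(q, t), k0 = load_tile(k, t);          // prefetch tile t
            acc = mma(q_cur, k_cur, acc);                             // consume the previous tile
            q_cur = q0; k_cur = k0;
            Tile q1 = load_tile(q, t + 1), k1 = load_tile(k, t + 1);  // prefetch tile t + 1
            acc = mma(q_cur, k_cur, acc);
            q_cur = q1; k_cur = k1;
        }
        if (t < num_tiles) {                                          // handle odd tail
            Tile qn = load_tile(q, t), kn = load_tile(k, t);
            acc = mma(q_cur, k_cur, acc);
            q_cur = qn; k_cur = kn;
        }
        return mma(q_cur, k_cur, acc);                                // drain the last prefetched tiles
    }

pipelined_dot assumes q.size() == k.size() >= 1, mirroring the shader's assumption that HEAD_DIM_QK is a positive multiple of SG_MAT_K.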
@@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
          head_dim_block < HEAD_DIM_V;
          head_dim_block += num_subgroups * SG_MAT_N) {
         // load O submatrix from shared memory
-        var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
+        var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
             &o_shmem,
             head_dim_block,
             false,
@@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         );
         for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
             let p_offset = kv_block * SG_MAT_N;
-            var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
+            var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
                 &inter_shmem,
                 p_offset,
                 false,
@@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
 #ifdef KV_DIRECT
             let v_block_row = kv_tile + kv_block * SG_MAT_N;
             let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
-            var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+            var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
                 &V,
                 v_global_offset,
                 false,
@@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
             );
 #else
             let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
-            var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+            var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
                 &kv_shmem,
                 v_block_offset + head_dim_block,
                 false,
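Why every template argument pair in this file is swapped rather than reworked: the chromium_experimental_subgroup_matrix types take their dimensions as <T, columns, rows>, as the corrected usage above implies, so a left operand of shape M rows x K columns is spelled subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>. For square configs (M == N == K) the old <rows, columns> ordering named exactly the same types, which is why the bug stayed hidden until non-square Intel configs appeared. A small sketch of the corrected mapping for an (M x K) * (K x N) = (M x N) multiply (helper names hypothetical):

    #include <cstdint>
    #include <cstdio>

    // Maps a subgroup matrix config (M, N, K) to the <columns, rows> pairs
    // used by the WGSL subgroup matrix types in this shader.
    struct SgMatShape { uint32_t cols, rows; };

    static SgMatShape left_shape  (uint32_t m, uint32_t k) { return { k, m }; }  // Q tile:   M x K
    static SgMatShape right_shape (uint32_t n, uint32_t k) { return { n, k }; }  // K/V tile: K x N
    static SgMatShape result_shape(uint32_t m, uint32_t n) { return { n, m }; }  // acc:      M x N

    int main() {
        // Intel Xe2-style non-square config: M = 8, N = 16, K = 16.
        const SgMatShape l = left_shape(8, 16);    // subgroup_matrix_left<f16, 16, 8>
        const SgMatShape r = right_shape(16, 16);  // subgroup_matrix_right<f16, 16, 16>
        const SgMatShape o = result_shape(8, 16);  // subgroup_matrix_result<f16, 16, 8>
        std::printf("left<%u,%u> right<%u,%u> result<%u,%u>\n",
                    l.cols, l.rows, r.cols, r.rows, o.cols, o.rows);
        return 0;
    }

With a square 16x16x16 config all three helpers return {16, 16}, collapsing the distinction the old parameter order silently relied on.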