ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (#21669)
This commit is contained in:
parent
e4fed9d08d
commit
bfd1f453cb
|
|
@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
|
||||||
|
|
||||||
#ifndef __EMSCRIPTEN__
|
#ifndef __EMSCRIPTEN__
|
||||||
// Only support square f16 matrices of size 8 or 16 for now
|
// Accept f16 subgroup matrix configurations (square or non-square).
|
||||||
|
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
|
||||||
|
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
|
||||||
|
// The shaders are already parameterized to handle any M/N/K dimensions.
|
||||||
bool valid_subgroup_matrix_config = false;
|
bool valid_subgroup_matrix_config = false;
|
||||||
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
|
||||||
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
|
||||||
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
|
||||||
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
|
if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
||||||
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
|
|
||||||
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
|
||||||
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
|
||||||
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
|
||||||
|
|
@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
// Head dimensions must be divisible by subgroup matrix dimensions
|
||||||
|
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
|
||||||
|
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
// Head dimensions must fit in workgroup memory with minimum tile sizes
|
||||||
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||||
const bool has_mask = op->src[3] != nullptr;
|
const bool has_mask = op->src[3] != nullptr;
|
||||||
|
|
|
||||||
|
|
@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
#endif
|
#endif
|
||||||
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
|
||||||
let inter_offset = kv_block * SG_MAT_N;
|
let inter_offset = kv_block * SG_MAT_N;
|
||||||
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
|
var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE);
|
||||||
|
|
||||||
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK);
|
||||||
|
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1);
|
||||||
#else
|
#else
|
||||||
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
var t: u32 = 1u;
|
var t: u32 = 1u;
|
||||||
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
|
for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
|
||||||
let h0 = t * SG_MAT_K;
|
let h0 = t * SG_MAT_K;
|
||||||
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK);
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h0, true, params.stride_k1);
|
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1);
|
||||||
#else
|
#else
|
||||||
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
|
||||||
#endif
|
#endif
|
||||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||||
q_cur = q0;
|
q_cur = q0;
|
||||||
k_cur = k0;
|
k_cur = k0;
|
||||||
|
|
||||||
let h1 = (t + 1u) * SG_MAT_K;
|
let h1 = (t + 1u) * SG_MAT_K;
|
||||||
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK);
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h1, true, params.stride_k1);
|
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1);
|
||||||
#else
|
#else
|
||||||
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
|
||||||
#endif
|
#endif
|
||||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||||
q_cur = q1g;
|
q_cur = q1g;
|
||||||
|
|
@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
// handle odd tail
|
// handle odd tail
|
||||||
if (t < HEAD_DIM_QK / SG_MAT_K) {
|
if (t < HEAD_DIM_QK / SG_MAT_K) {
|
||||||
let h = t * SG_MAT_K;
|
let h = t * SG_MAT_K;
|
||||||
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(&q_shmem, h, false, HEAD_DIM_QK);
|
var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK);
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&K, k_global_offset + h, true, params.stride_k1);
|
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1);
|
||||||
#else
|
#else
|
||||||
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
|
||||||
#endif
|
#endif
|
||||||
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
|
||||||
q_cur = qn;
|
q_cur = qn;
|
||||||
|
|
@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
head_dim_block < HEAD_DIM_V;
|
head_dim_block < HEAD_DIM_V;
|
||||||
head_dim_block += num_subgroups * SG_MAT_N) {
|
head_dim_block += num_subgroups * SG_MAT_N) {
|
||||||
// load O submatrix from shared memory
|
// load O submatrix from shared memory
|
||||||
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
|
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(
|
||||||
&o_shmem,
|
&o_shmem,
|
||||||
head_dim_block,
|
head_dim_block,
|
||||||
false,
|
false,
|
||||||
|
|
@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
);
|
);
|
||||||
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
|
||||||
let p_offset = kv_block * SG_MAT_N;
|
let p_offset = kv_block * SG_MAT_N;
|
||||||
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
|
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(
|
||||||
&inter_shmem,
|
&inter_shmem,
|
||||||
p_offset,
|
p_offset,
|
||||||
false,
|
false,
|
||||||
|
|
@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
#ifdef KV_DIRECT
|
#ifdef KV_DIRECT
|
||||||
let v_block_row = kv_tile + kv_block * SG_MAT_N;
|
let v_block_row = kv_tile + kv_block * SG_MAT_N;
|
||||||
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
|
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
|
||||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||||
&V,
|
&V,
|
||||||
v_global_offset,
|
v_global_offset,
|
||||||
false,
|
false,
|
||||||
|
|
@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||||
);
|
);
|
||||||
#else
|
#else
|
||||||
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
|
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
|
||||||
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
|
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(
|
||||||
&kv_shmem,
|
&kv_shmem,
|
||||||
v_block_offset + head_dim_block,
|
v_block_offset + head_dim_block,
|
||||||
false,
|
false,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue