ggml-webgpu: add tile flash attention fallback

Zheyuan Chen 2026-04-20 17:52:15 -07:00
parent 8bc492ebb4
commit 47e4de3169
4 changed files with 434 additions and 63 deletions

View File

@@ -436,6 +436,12 @@ struct ggml_webgpu_unary_pipeline_key_hash {
/** FlashAttention */
enum ggml_webgpu_flash_attn_path : uint32_t {
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u,
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u,
};
struct ggml_webgpu_flash_attn_pipeline_key {
ggml_type kv_type;
uint32_t head_dim_qk;
@@ -444,11 +450,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t path;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
uses_logit_softcap == other.uses_logit_softcap && path == other.path;
}
};
@@ -462,6 +469,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
ggml_webgpu_hash_combine(seed, key.path);
return seed;
}
};
@@ -476,6 +484,43 @@ struct ggml_webgpu_flash_attn_vec_decisions {
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};
struct ggml_webgpu_flash_attn_tile_policy {
uint32_t q_tile = 4u;
uint32_t kv_granularity = 32u;
uint32_t max_kv_tile = 64u;
uint32_t wg_size = GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE;
};
inline ggml_webgpu_flash_attn_tile_policy ggml_webgpu_get_flash_attn_tile_policy(uint32_t max_subgroup_size) {
ggml_webgpu_flash_attn_tile_policy policy = {};
const uint32_t subgroup_width = std::max(1u, max_subgroup_size);
policy.q_tile = std::max(1u, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE / subgroup_width);
policy.kv_granularity = subgroup_width;
policy.max_kv_tile = std::max(64u, subgroup_width);
policy.wg_size = subgroup_width * policy.q_tile;
return policy;
}
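// Worked example (illustrative; assumes GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE == 128):
// an adapter reporting max_subgroup_size = 32 yields
//   q_tile         = max(1, 128 / 32) = 4
//   kv_granularity = 32
//   max_kv_tile    = max(64, 32)      = 64
//   wg_size        = 32 * 4           = 128
// matching the fallback defaults baked into the tile shader (Q_TILE 4, KV_TILE 64, WG_SIZE 128).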
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
key.head_dim_qk != key.head_dim_v) {
return 1u;
}
switch (key.head_dim_qk) {
case 64:
case 192:
case 576:
return 2u;
case 96:
return 4u;
default:
return 1u;
}
}
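// Illustrative use: an f16 K/V cache with head_dim_qk == head_dim_v == 96 on the
// vec path gets vec_ne = 4; head dims 64/192/576 get 2; everything else keeps
// the conservative 1.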
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
const ggml_webgpu_shader_lib_context & context) {
@@ -492,6 +537,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
key.has_mask = has_mask;
key.has_sinks = has_sinks;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
key.path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
return key;
}
@@ -2045,8 +2091,10 @@ class ggml_webgpu_shader_lib {
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
auto it = flash_attn_pipelines.find(key);
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
key.path = context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX
: GGML_WEBGPU_FLASH_ATTN_PATH_TILE;
auto it = flash_attn_pipelines.find(key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}
@@ -2094,40 +2142,56 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile =
std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
uint32_t wg_size = 0;
const char * shader_src = wgsl_flash_attn;
auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
decisions->q_tile = context.sg_mat_m;
const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const auto tile_policy = ggml_webgpu_get_flash_attn_tile_policy(context.max_subgroup_size);
q_tile = tile_policy.q_tile;
kv_tile = std::min(tile_policy.max_kv_tile, ggml_webgpu_flash_attn_max_kv_tile(context, key));
kv_tile = std::max(context.sg_mat_n, (kv_tile / context.sg_mat_n) * context.sg_mat_n);
wg_size = tile_policy.wg_size;
shader_src = wgsl_flash_attn_tile;
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
variant += "_tile";
} else {
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
}
if (key.kv_direct) {
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
}
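// Illustration of the rounding above: if GGML_WEBGPU_KV_SEQ_PAD were 64 and
// sg_mat_n were 16, a kv_tile of 48 fails 64 % 48 == 0 and steps down to 32,
// so kv_direct tiles always divide the padded KV sequence evenly.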
auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
decisions->wg_size = wg_size;
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
pipeline.context = decisions;
flash_attn_pipelines[key] = pipeline;
return flash_attn_pipelines[key];
}
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
auto it = flash_attn_vec_pipelines.find(key);
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
key.path = GGML_WEBGPU_FLASH_ATTN_PATH_VEC;
auto it = flash_attn_vec_pipelines.find(key);
if (it != flash_attn_vec_pipelines.end()) {
return it->second;
}
@@ -2185,23 +2249,7 @@ class ggml_webgpu_shader_lib {
auto decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
decisions->wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
uint32_t vec_ne = 1u;
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
switch (key.head_dim_qk) {
case 64:
case 192:
case 576:
vec_ne = 2u;
break;
case 96:
vec_ne = 4u;
break;
default:
break;
}
}
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(key);
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));

View File

@@ -402,10 +402,27 @@ static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
return global_ctx->capabilities.supports_subgroup_matrix && (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) &&
(V->ne[0] % 4 == 0) && kv_vec_type_supported &&
(K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
}
static bool ggml_webgpu_flash_attn_use_tile(webgpu_global_context & global_ctx,
const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V) {
const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const uint32_t k_offset_elems =
(uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
const uint32_t v_offset_elems =
(uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
return global_ctx->capabilities.supports_subgroups && !global_ctx->capabilities.supports_subgroup_matrix &&
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(Q->ne[0] % 4 == 0) && (V->ne[0] % 4 == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
}
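// Alignment example (illustrative values): with a 256-byte
// minStorageBufferOffsetAlignment and f16 K/V (2 bytes per element), a K offset
// of 256*n + 40 bytes gives k_offset_elems = 40 / 2 = 20; 20 % 4 == 0, so the
// tile shader's vec4<f16> loads stay element-aligned and the tensor qualifies.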
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
@@ -1638,6 +1655,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
shader_lib_ctx.src3 = mask;
shader_lib_ctx.src4 = sinks;
shader_lib_ctx.dst = dst;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
@@ -1645,6 +1664,14 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
const bool use_tile = ggml_webgpu_flash_attn_use_tile(ctx->global_ctx, Q, K, V);
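// The tile fallback reuses the sg_mat_* context fields to carry its tile
// dimensions (Q rows and KV granularity), so the shared sizing logic below
// needs no separate plumbing.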
if (use_tile) {
const auto tile_policy = ggml_webgpu_get_flash_attn_tile_policy(shader_lib_ctx.max_subgroup_size);
shader_lib_ctx.supports_subgroup_matrix = false;
shader_lib_ctx.sg_mat_m = tile_policy.q_tile;
shader_lib_ctx.sg_mat_n = tile_policy.kv_granularity;
shader_lib_ctx.sg_mat_k = tile_policy.kv_granularity;
}
webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) :
ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
@@ -3431,12 +3458,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
ctx->webgpu_global_ctx->capabilities.supports_subgroups =
ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
bool valid_subgroup_matrix_config = false;
#ifndef __EMSCRIPTEN__
// Accept f16 subgroup matrix configurations (square or non-square).
// NVIDIA GPUs typically report square configs (e.g. 16x16x16),
// while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
// The shaders are already parameterized to handle any M/N/K dimensions.
bool valid_subgroup_matrix_config = false;
if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
@@ -3450,8 +3477,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
}
}
}
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
#endif
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
@@ -3782,32 +3809,69 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
case GGML_OP_FLASH_ATTN_EXT:
{
#ifndef __EMSCRIPTEN__
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
break;
}
// Head dimensions must be divisible by subgroup matrix dimensions
if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 &&
(src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
break;
}
supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
if (!supports_op) {
break;
}
#ifndef __EMSCRIPTEN__
const bool kv_vec_type_supported =
src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0;
const bool use_vec = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix &&
(src0->ne[1] < 20) && (src0->ne[0] % 32 == 0) && (src2->ne[0] % 4 == 0) &&
kv_vec_type_supported && src2->type == src1->type;
const bool use_tile =
ctx->webgpu_global_ctx->capabilities.supports_subgroups &&
src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 &&
(src0->ne[0] % 4 == 0) && (src2->ne[0] % 4 == 0) &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0) &&
!use_vec && !ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
const auto tile_policy =
ggml_webgpu_get_flash_attn_tile_policy(ctx->webgpu_global_ctx->capabilities.max_subgroup_size);
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
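// Path precedence below mirrors dispatch: vec first, then the tile fallback,
// then the subgroup-matrix path; each checks its own workgroup-memory bound.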
if (use_vec) {
const bool kv_direct =
src1->type == GGML_TYPE_F16 && (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
if (use_tile) {
const bool kv_direct =
src1->type == GGML_TYPE_F16 && (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
tile_policy.q_tile, tile_policy.kv_granularity,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
supports_op = false;
break;
}
const bool kv_direct =
src1->type == GGML_TYPE_F16 && (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
(uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
if (min_bytes > limit_bytes) {
supports_op = false;
}
#else
supports_op = false;
#endif
break;
}

View File

@@ -0,0 +1,261 @@
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define Q_TILE 4
#define KV_TILE 64
#define WG_SIZE 128
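// The defines above are fallback defaults; the host-side preprocessor overrides
// HEAD_DIM_QK, HEAD_DIM_V, Q_TILE, KV_TILE, WG_SIZE and MAX_SUBGROUP_SIZE for
// each compiled pipeline variant.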
struct Params {
offset_q: u32,
offset_k: u32,
offset_v: u32,
offset_mask: u32,
offset_sinks: u32,
offset_dst: u32,
n_heads: u32,
seq_len_q: u32,
seq_len_kv: u32,
stride_q1: u32,
stride_q2: u32,
stride_q3: u32,
stride_k1: u32,
stride_k2: u32,
stride_k3: u32,
stride_v1: u32,
stride_v2: u32,
stride_v3: u32,
stride_mask3: u32,
q_per_kv: u32,
scale: f32,
max_bias: f32,
logit_softcap: f32,
n_head_log2: f32,
m0: f32,
m1: f32,
};
@group(0) @binding(0) var<storage, read> Q: array<f32>;
@group(0) @binding(1) var<storage, read> K: array<vec4<f16>>;
@group(0) @binding(2) var<storage, read> V: array<vec4<f16>>;
#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read> mask: array<f16>;
@group(0) @binding(4) var<storage, read> sinks: array<f32>;
#define DST_BINDING 5
#define PARAMS_BINDING 6
#elif defined(MASK)
@group(0) @binding(3) var<storage, read> mask: array<f16>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read> sinks: array<f32>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#else
#define DST_BINDING 3
#define PARAMS_BINDING 4
#endif
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
const FLOAT_MIN: f32 = -1.0e9;
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
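// Each subgroup handles one Q row; the ceil-divisions above spread KV_TILE
// scores and V_CHUNKS output vec4s across its lanes, so a lane may own several
// registers when the subgroup is narrower than the tile.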
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
if (subgroup_size == 0u || num_subgroups < Q_TILE) {
return;
}
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
let wg_per_batch = wg_per_head * params.n_heads;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
let batch_idx = wg_id.x / wg_per_batch;
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
let wg_in_batch = wg_id.x % wg_per_batch;
let head_idx = wg_in_batch / wg_per_head;
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
let k_head_idx = head_idx / params.q_per_kv;
let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
let wg_in_head = wg_in_batch % wg_per_head;
let q_row_start = wg_in_head * Q_TILE;
let global_q_row = q_row_start + subgroup_id;
let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
#ifdef MASK
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
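// ALiBi slope for this head: 1.0 unless max_bias > 0, in which case heads below
// n_head_log2 use m0^(head + 1) and the remainder use m1^(2*(head - n_head_log2) + 1).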
let head = f32(head_idx);
let slope = select(1.0,
select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
pow(params.m0, head + 1.0),
head < params.n_head_log2),
params.max_bias > 0.0);
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let q_tile_row = elem_idx / HEAD_DIM_QK;
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_tile_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
0.0,
Q[global_q_row_offset + q_col] * params.scale,
head_q_row < params.seq_len_q));
}
workgroupBarrier();
var row_max = FLOAT_MIN;
var exp_sum = 0.0;
var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
out_regs[reg_idx] = vec4<f32>(0.0);
}
let q_base = subgroup_id * HEAD_DIM_QK;
let subgroup_p_offset = subgroup_id * KV_TILE;
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
var local_scores: array<f32, SCORE_REGS_PER_LANE>;
for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
local_scores[slot] = FLOAT_MIN;
}
var local_max = FLOAT_MIN;
if (row_active) {
for (var slot = 0u; slot < score_slots; slot += 1u) {
let kv_local = sg_inv_id + slot * subgroup_size;
if (kv_local >= kv_count) {
continue;
}
let global_k_row = kv_tile + kv_local;
var dot_val = 0.0;
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
let q_off = q_base + chunk * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
dot_val += dot(qv, vec4<f32>(K[k_vec_index]));
}
#ifdef LOGIT_SOFTCAP
dot_val = params.logit_softcap * tanh(dot_val);
#endif
#ifdef MASK
let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
dot_val += slope * f32(mask[mask_idx]);
#endif
local_scores[slot] = dot_val;
local_max = max(local_max, dot_val);
}
}
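// Online-softmax update: fold this tile's maximum into the running row maximum,
// then rescale the existing normalizer and output accumulators by
// exp(row_max - new_max) so earlier tiles stay consistent with the tighter bound.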
let tile_max = subgroupMax(local_max);
let new_max = max(row_max, tile_max);
let cur_exp = exp(row_max - new_max);
exp_sum *= cur_exp;
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
out_regs[reg_idx] *= cur_exp;
}
var local_sum = 0.0;
for (var slot = 0u; slot < score_slots; slot += 1u) {
let kv_local = sg_inv_id + slot * subgroup_size;
if (row_active && kv_local < kv_count) {
let p = exp(local_scores[slot] - new_max);
p_shmem[subgroup_p_offset + kv_local] = p;
local_sum += p;
}
}
workgroupBarrier();
let tile_sum = subgroupAdd(local_sum);
exp_sum += tile_sum;
row_max = new_max;
if (row_active) {
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
let chunk = sg_inv_id + reg_idx * subgroup_size;
if (chunk >= V_CHUNKS) {
continue;
}
var acc = out_regs[reg_idx];
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
let p = p_shmem[subgroup_p_offset + kv_local];
let global_v_row = kv_tile + kv_local;
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
acc += p * vec4<f32>(V[v_vec_index]);
}
out_regs[reg_idx] = acc;
}
}
workgroupBarrier();
}
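// Attention sinks contribute one extra virtual logit per head; fold it into the
// running max and normalizer exactly like another KV tile before normalizing.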
#ifdef SINKS
if (row_active) {
let sink_score = sinks[params.offset_sinks + head_idx];
let sink_max = max(row_max, sink_score);
let sink_scale = exp(row_max - sink_max);
for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
out_regs[reg_idx] *= sink_scale;
}
exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
row_max = sink_max;
}
#endif
if (row_active) {
let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
let row_base = dst_global_offset + subgroup_id * dst2_stride;
let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
let chunk = sg_inv_id + reg_idx * subgroup_size;
if (chunk >= V_CHUNKS) {
continue;
}
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
}
}
}

View File

@@ -1,8 +1,6 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32