diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 449eae808e..af46d86b2e 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -436,6 +436,12 @@ struct ggml_webgpu_unary_pipeline_key_hash {
 
 /** FlashAttention */
 
+enum ggml_webgpu_flash_attn_path : uint32_t {
+    GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_TILE            = 1u,
+    GGML_WEBGPU_FLASH_ATTN_PATH_VEC             = 2u,
+};
+
 struct ggml_webgpu_flash_attn_pipeline_key {
     ggml_type kv_type;
     uint32_t  head_dim_qk;
@@ -444,11 +450,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
+    uint32_t  path;
 
     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
                kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
+               uses_logit_softcap == other.uses_logit_softcap && path == other.path;
     }
 };
 
@@ -462,6 +469,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        ggml_webgpu_hash_combine(seed, key.path);
         return seed;
     }
 };
@@ -476,6 +484,43 @@ struct ggml_webgpu_flash_attn_vec_decisions {
     uint32_t kv_tile = 0;
     uint32_t wg_size = 0;
 };
 
+struct ggml_webgpu_flash_attn_tile_policy {
+    uint32_t q_tile         = 4u;
+    uint32_t kv_granularity = 32u;
+    uint32_t max_kv_tile    = 64u;
+    uint32_t wg_size        = GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE;
+};
+
+inline ggml_webgpu_flash_attn_tile_policy ggml_webgpu_get_flash_attn_tile_policy(uint32_t max_subgroup_size) {
+    ggml_webgpu_flash_attn_tile_policy policy = {};
+    const uint32_t subgroup_width = std::max(1u, max_subgroup_size);
+
+    policy.q_tile         = std::max(1u, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE / subgroup_width);
+    policy.kv_granularity = subgroup_width;
+    policy.max_kv_tile    = std::max(64u, subgroup_width);
+    policy.wg_size        = subgroup_width * policy.q_tile;
+
+    return policy;
+}
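+
+// Illustrative mapping (assuming GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE is 128):
+//   max_subgroup_size = 16 -> q_tile = 8, kv_granularity = 16, max_kv_tile = 64, wg_size = 128
+//   max_subgroup_size = 32 -> q_tile = 4, kv_granularity = 32, max_kv_tile = 64, wg_size = 128
+//   max_subgroup_size = 64 -> q_tile = 2, kv_granularity = 64, max_kv_tile = 64, wg_size = 128
+// i.e. the tile path assigns one subgroup per Q row and sizes the workgroup to cover q_tile rows.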
+
+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+    // Keep conservative defaults unless this is the f16 vec-split shape family.
+    if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
+        key.head_dim_qk != key.head_dim_v) {
+        return 1u;
+    }
+
+    switch (key.head_dim_qk) {
+        case 64:
+        case 192:
+        case 576:
+            return 2u;
+        case 96:
+            return 4u;
+        default:
+            return 1u;
+    }
+}
 
 inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
     const ggml_webgpu_shader_lib_context & context) {
@@ -492,6 +537,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_
     key.has_mask           = has_mask;
     key.has_sinks          = has_sinks;
     key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
+    key.path               = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX;
     return key;
 }
 
@@ -2045,8 +2091,10 @@ class ggml_webgpu_shader_lib {
     }
 
     webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-        auto it = flash_attn_pipelines.find(key);
+        ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
+        key.path = context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX
+                                                    : GGML_WEBGPU_FLASH_ATTN_PATH_TILE;
+        auto it  = flash_attn_pipelines.find(key);
         if (it != flash_attn_pipelines.end()) {
             return it->second;
         }
@@ -2094,40 +2142,56 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
         variant += std::string("_hsv") + std::to_string(key.head_dim_v);
 
-        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+        uint32_t q_tile  = context.sg_mat_m;
+        uint32_t kv_tile =
+            std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key),
+                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+        uint32_t     wg_size    = 0;
+        const char * shader_src = wgsl_flash_attn;
 
-        auto decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>();
-        decisions->q_tile = context.sg_mat_m;
-
-        const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
-        uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+        if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
+            const auto tile_policy = ggml_webgpu_get_flash_attn_tile_policy(context.max_subgroup_size);
+            q_tile     = tile_policy.q_tile;
+            kv_tile    = std::min(tile_policy.max_kv_tile, ggml_webgpu_flash_attn_max_kv_tile(context, key));
+            kv_tile    = std::max(context.sg_mat_n, (kv_tile / context.sg_mat_n) * context.sg_mat_n);
+            wg_size    = tile_policy.wg_size;
+            shader_src = wgsl_flash_attn_tile;
+            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size));
+            variant += "_tile";
+        } else {
+            defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+            defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+            defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+            wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        }
 
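+        // The kv_direct path requires KV_TILE to divide GGML_WEBGPU_KV_SEQ_PAD
+        // evenly, so shrink the tile in sg_mat_n steps until it does.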
         if (key.kv_direct) {
-            kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
+            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
             while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                 kv_tile -= context.sg_mat_n;
             }
         }
 
+        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_decisions>();
+        decisions->q_tile  = q_tile;
         decisions->kv_tile = kv_tile;
-        decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        decisions->wg_size = wg_size;
 
-        defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile));
-        defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
-        defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
+        defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
+        defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
         webgpu_pipeline pipeline =
-            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant);
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
         pipeline.context          = decisions;
         flash_attn_pipelines[key] = pipeline;
         return flash_attn_pipelines[key];
     }
 
     webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
-        auto it = flash_attn_vec_pipelines.find(key);
+        ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context);
+        key.path = GGML_WEBGPU_FLASH_ATTN_PATH_VEC;
+        auto it  = flash_attn_vec_pipelines.find(key);
         if (it != flash_attn_vec_pipelines.end()) {
             return it->second;
         }
@@ -2185,23 +2249,7 @@ class ggml_webgpu_shader_lib {
         auto decisions     = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>();
         decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context);
         decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size));
-        uint32_t vec_ne = 1u;
-
-        // Keep conservative defaults unless this is the f16 vec-split shape family.
-        if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) {
-            switch (key.head_dim_qk) {
-                case 64:
-                case 192:
-                case 576:
-                    vec_ne = 2u;
-                    break;
-                case 96:
-                    vec_ne = 4u;
-                    break;
-                default:
-                    break;
-            }
-        }
+        const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(key);
 
         defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile));
         defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size));
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index acc486cfdd..b9fd0d0f37 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -402,10 +402,27 @@ static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx,
     const bool kv_vec_type_supported =
         K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
 
-    return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+    return global_ctx->capabilities.supports_subgroup_matrix && (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) &&
+           (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
            (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
 }
 
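+// The tile path is the fallback for adapters with plain subgroup support but no
+// subgroup-matrix feature. It is deliberately narrow: f16 K/V with vec4-aligned
+// base offsets, head dims divisible by 4, and a KV length that is a multiple of
+// GGML_WEBGPU_KV_SEQ_PAD.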
+static bool ggml_webgpu_flash_attn_use_tile(webgpu_global_context & global_ctx,
+                                            const ggml_tensor *     Q,
+                                            const ggml_tensor *     K,
+                                            const ggml_tensor *     V) {
+    const size_t   alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+    const uint32_t k_offset_elems =
+        (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems =
+        (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
+    const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
+
+    return global_ctx->capabilities.supports_subgroups && !global_ctx->capabilities.supports_subgroup_matrix &&
+           K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
+           (Q->ne[0] % 4 == 0) && (V->ne[0] % 4 == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+}
+
 static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
     return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
@@ -1638,6 +1655,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     shader_lib_ctx.src3 = mask;
     shader_lib_ctx.src4 = sinks;
     shader_lib_ctx.dst  = dst;
+    shader_lib_ctx.supports_subgroups       = ctx->global_ctx->capabilities.supports_subgroups;
+    shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
     shader_lib_ctx.max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
     shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
     shader_lib_ctx.sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m;
@@ -1645,6 +1664,14 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
     shader_lib_ctx.sg_mat_k          = ctx->global_ctx->capabilities.sg_mat_k;
     shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
     const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V);
+    const bool use_tile = ggml_webgpu_flash_attn_use_tile(ctx->global_ctx, Q, K, V);
+    if (use_tile) {
+        const auto tile_policy = ggml_webgpu_get_flash_attn_tile_policy(shader_lib_ctx.max_subgroup_size);
+        shader_lib_ctx.supports_subgroup_matrix = false;
+        shader_lib_ctx.sg_mat_m = tile_policy.q_tile;
+        shader_lib_ctx.sg_mat_n = tile_policy.kv_granularity;
+        shader_lib_ctx.sg_mat_k = tile_policy.kv_granularity;
+    }
     webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx)
                                        : ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
@@ -3431,12 +3458,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     ctx->webgpu_global_ctx->capabilities.supports_subgroups =
         ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups);
 
+    bool valid_subgroup_matrix_config = false;
 #ifndef __EMSCRIPTEN__
     // Accept f16 subgroup matrix configurations (square or non-square).
     // NVIDIA GPUs typically report square configs (e.g. 16x16x16),
     // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16).
     // The shaders are already parameterized to handle any M/N/K dimensions.
-    bool valid_subgroup_matrix_config = false;
     if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
         for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
             const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
@@ -3450,8 +3477,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
             }
         }
     }
-    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
 #endif
+    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
 
     // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
     // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
@@ -3782,32 +3809,69 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             break;
         case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef __EMSCRIPTEN__
-            if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
-                break;
-            }
-            // Head dimensions must be divisible by subgroup matrix dimensions
-            if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 ||
-                src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) {
-                break;
-            }
-            // Head dimensions must fit in workgroup memory with minimum tile sizes
-            size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
-            const bool has_mask = op->src[3] != nullptr;
-            const bool kv_direct = src1->type == GGML_TYPE_F16 &&
-                                   (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
-                                   (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
-            const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
-                ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
-                (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
-            if (min_bytes > limit_bytes) {
-                break;
-            }
-
             supports_op = src0->type == GGML_TYPE_F32 &&
                           (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
                            src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
                           src2->type == src1->type && op->type == GGML_TYPE_F32;
+            if (!supports_op) {
+                break;
+            }
+#ifndef __EMSCRIPTEN__
+            const bool kv_vec_type_supported =
+                src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0;
+            const bool use_vec = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix &&
+                                 (src0->ne[1] < 20) && (src0->ne[0] % 32 == 0) && (src2->ne[0] % 4 == 0) &&
+                                 kv_vec_type_supported && src2->type == src1->type;
+            const bool use_tile =
+                ctx->webgpu_global_ctx->capabilities.supports_subgroups &&
+                src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 &&
+                (src0->ne[0] % 4 == 0) && (src2->ne[0] % 4 == 0) &&
+                (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0) &&
+                !use_vec && !ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
+            const auto tile_policy =
+                ggml_webgpu_get_flash_attn_tile_policy(ctx->webgpu_global_ctx->capabilities.max_subgroup_size);
+            const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+            const bool   has_mask    = op->src[3] != nullptr;
+            if (use_vec) {
+                const bool kv_direct =
+                    src1->type == GGML_TYPE_F16 &&
+                    (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
+                    (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
+                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
+                if (min_bytes > limit_bytes) {
+                    supports_op = false;
+                }
+                break;
+            }
+
+            if (use_tile) {
+                const bool kv_direct =
+                    src1->type == GGML_TYPE_F16 && (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+                    tile_policy.q_tile, tile_policy.kv_granularity,
+                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
+                if (min_bytes > limit_bytes) {
+                    supports_op = false;
+                }
+                break;
+            }
+
+            if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+                supports_op = false;
+                break;
+            }
+            const bool kv_direct =
+                src1->type == GGML_TYPE_F16 &&
+                (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
+                (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+            const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+                ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
+                (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
+            if (min_bytes > limit_bytes) {
+                supports_op = false;
+            }
+#else
+            supports_op = false;
 #endif
             break;
         }
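The new shader below is the tile path itself. Each workgroup processes Q_TILE query rows of one head, one subgroup per row: the Q rows are staged into workgroup memory once, then the KV sequence is walked in KV_TILE-row slices. For each slice, the lanes of a subgroup compute their share of the Q·K^T scores, a streaming-softmax update rescales the running maximum, exponential sum, and output accumulators, and the exp-weighted V rows are folded into per-lane vec4 registers before the normalized result is written out.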
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl
new file mode 100644
index 0000000000..0725cf23bb
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl
@@ -0,0 +1,261 @@
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+#define Q_TILE 4
+#define KV_TILE 64
+#define WG_SIZE 128
+
+struct Params {
+    offset_q: u32,
+    offset_k: u32,
+    offset_v: u32,
+    offset_mask: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    n_heads: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+
+    stride_q1: u32,
+    stride_q2: u32,
+    stride_q3: u32,
+    stride_k1: u32,
+    stride_k2: u32,
+    stride_k3: u32,
+    stride_v1: u32,
+    stride_v2: u32,
+    stride_v3: u32,
+    stride_mask3: u32,
+
+    q_per_kv: u32,
+
+    scale: f32,
+    max_bias: f32,
+    logit_softcap: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0) var<storage, read> Q: array<f32>;
+@group(0) @binding(1) var<storage, read> K: array<vec4<f16>>;
+@group(0) @binding(2) var<storage, read> V: array<vec4<f16>>;
+
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read> mask: array<f16>;
+@group(0) @binding(4) var<storage, read> sinks: array<f32>;
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read> mask: array<f16>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#endif
+
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
+const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
+const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
+const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
+
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    if (subgroup_size == 0u || num_subgroups < Q_TILE) {
+        return;
+    }
+
+    let wg_per_head  = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_batch = wg_per_head * params.n_heads;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+
+    let batch_idx        = wg_id.x / wg_per_batch;
+    let q_batch_offset   = params.offset_q + batch_idx * params.stride_q3;
+    let k_batch_offset   = params.offset_k + batch_idx * params.stride_k3;
+    let v_batch_offset   = params.offset_v + batch_idx * params.stride_v3;
+    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
+    let wg_in_batch      = wg_id.x % wg_per_batch;
+
+    let head_idx      = wg_in_batch / wg_per_head;
+    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+    let k_head_idx    = head_idx / params.q_per_kv;
+    let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2;
+    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+
+    let wg_in_head   = wg_in_batch % wg_per_head;
+    let q_row_start  = wg_in_head * Q_TILE;
+    let global_q_row = q_row_start + subgroup_id;
+    let row_active   = subgroup_id < Q_TILE && global_q_row < params.seq_len_q;
+
+#ifdef MASK
+    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
+
+    let head  = f32(head_idx);
+    let slope = select(1.0,
+                       select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
+                              pow(params.m0, head + 1.0),
+                              head < params.n_head_log2),
+                       params.max_bias > 0.0);
+
+    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let q_tile_row = elem_idx / HEAD_DIM_QK;
+        let q_col      = elem_idx % HEAD_DIM_QK;
+        let head_q_row = q_row_start + q_tile_row;
+        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+        q_shmem[elem_idx] = f16(select(
+            0.0,
+            Q[global_q_row_offset + q_col] * params.scale,
+            head_q_row < params.seq_len_q));
+    }
+
+    workgroupBarrier();
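+
+    // Streaming-softmax state for this subgroup's Q row: the running maximum,
+    // the running sum of exponentials, and unnormalized vec4 output accumulators
+    // that are rescaled whenever a later KV tile raises the maximum.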
+    var row_max = FLOAT_MIN;
+    var exp_sum = 0.0;
+    var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>;
+    for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+        out_regs[reg_idx] = vec4<f32>(0.0);
+    }
+
+    let q_base            = subgroup_id * HEAD_DIM_QK;
+    let subgroup_p_offset = subgroup_id * KV_TILE;
+
+    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
+        let kv_count    = min(KV_TILE, params.seq_len_kv - kv_tile);
+        let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size);
+        let out_slots   = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
+        var local_scores: array<f32, SCORE_REGS_PER_LANE>;
+        for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) {
+            local_scores[slot] = FLOAT_MIN;
+        }
+
+        var local_max = FLOAT_MIN;
+        if (row_active) {
+            for (var slot = 0u; slot < score_slots; slot += 1u) {
+                let kv_local = sg_inv_id + slot * subgroup_size;
+                if (kv_local >= kv_count) {
+                    continue;
+                }
+
+                let global_k_row = kv_tile + kv_local;
+                var dot_val = 0.0;
+                for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
+                    let q_off = q_base + chunk * 4u;
+                    let qv = vec4<f32>(
+                        f32(q_shmem[q_off + 0u]),
+                        f32(q_shmem[q_off + 1u]),
+                        f32(q_shmem[q_off + 2u]),
+                        f32(q_shmem[q_off + 3u]));
+                    let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
+                    dot_val += dot(qv, vec4<f32>(K[k_vec_index]));
+                }
+#ifdef LOGIT_SOFTCAP
+                dot_val = params.logit_softcap * tanh(dot_val);
+#endif
+#ifdef MASK
+                let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row;
+                dot_val += slope * f32(mask[mask_idx]);
+#endif
+                local_scores[slot] = dot_val;
+                local_max = max(local_max, dot_val);
+            }
+        }
+
+        let tile_max = subgroupMax(local_max);
+        let new_max  = max(row_max, tile_max);
+        let cur_exp  = exp(row_max - new_max);
+        exp_sum     *= cur_exp;
+        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+            out_regs[reg_idx] *= cur_exp;
+        }
+
+        var local_sum = 0.0;
+        for (var slot = 0u; slot < score_slots; slot += 1u) {
+            let kv_local = sg_inv_id + slot * subgroup_size;
+            if (row_active && kv_local < kv_count) {
+                let p = exp(local_scores[slot] - new_max);
+                p_shmem[subgroup_p_offset + kv_local] = p;
+                local_sum += p;
+            }
+        }
+
+        workgroupBarrier();
+
+        let tile_sum = subgroupAdd(local_sum);
+        exp_sum += tile_sum;
+        row_max  = new_max;
+
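+        // P·V accumulation for this tile: lanes divide the V_CHUNKS vec4 columns
+        // among themselves and each walks the tile's KV rows for its columns,
+        // reading the shared probabilities this subgroup wrote above.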
+        if (row_active) {
+            for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
+                let chunk = sg_inv_id + reg_idx * subgroup_size;
+                if (chunk >= V_CHUNKS) {
+                    continue;
+                }
+
+                var acc = out_regs[reg_idx];
+                for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
+                    let p = p_shmem[subgroup_p_offset + kv_local];
+                    let global_v_row = kv_tile + kv_local;
+                    let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
+                    acc += p * vec4<f32>(V[v_vec_index]);
+                }
+                out_regs[reg_idx] = acc;
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+#ifdef SINKS
+    if (row_active) {
+        let sink_score = sinks[params.offset_sinks + head_idx];
+        let sink_max   = max(row_max, sink_score);
+        let sink_scale = exp(row_max - sink_max);
+        for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) {
+            out_regs[reg_idx] *= sink_scale;
+        }
+        exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max);
+        row_max = sink_max;
+    }
+#endif
+
+    if (row_active) {
+        let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+        let row_base    = dst_global_offset + subgroup_id * dst2_stride;
+        let out_slots   = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size);
+        for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) {
+            let chunk = sg_inv_id + reg_idx * subgroup_size;
+            if (chunk >= V_CHUNKS) {
+                continue;
+            }
+            let dst_vec_index = (row_base + chunk * 4u) >> 2u;
+            dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
+        }
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
index a52575871a..7b2704ba9c 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
@@ -1,8 +1,6 @@
-diagnostic(off, chromium.subgroup_matrix_uniformity);
 diagnostic(off, subgroup_uniformity);
 enable f16;
 enable subgroups;
-enable chromium_experimental_subgroup_matrix;
 
 #ifdef KV_F32
 #define KV_TYPE f32
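For reference, the per-row update inside the tile shader's KV loop is the standard streaming ("online") softmax recurrence. Below is a minimal scalar sketch in C++; it is illustrative only, the names are hypothetical and do not correspond to helpers in the patch, and it assumes one f32 score per KV row:

    // Scalar model of the per-row streaming softmax used by flash_attn_tile.wgsl.
    // `scores` holds one tile's Q.K^T row; `v` holds the matching V rows.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    struct row_state {
        float              row_max = -1.0e9f;  // matches FLOAT_MIN in the shader
        float              exp_sum = 0.0f;
        std::vector<float> out;                // unnormalized output accumulator

        explicit row_state(size_t head_dim_v) : out(head_dim_v, 0.0f) {}

        void add_tile(const std::vector<float> & scores, const std::vector<std::vector<float>> & v) {
            // 1. New running maximum over this tile's scores.
            float tile_max = row_max;
            for (float s : scores) {
                tile_max = std::max(tile_max, s);
            }
            // 2. Rescale previously accumulated state to the new maximum.
            const float cur_exp = std::exp(row_max - tile_max);
            exp_sum *= cur_exp;
            for (float & o : out) {
                o *= cur_exp;
            }
            // 3. Accumulate exp(score - max) and the exp-weighted V rows.
            for (size_t i = 0; i < scores.size(); i++) {
                const float p = std::exp(scores[i] - tile_max);
                exp_sum += p;
                for (size_t d = 0; d < out.size(); d++) {
                    out[d] += p * v[i][d];
                }
            }
            row_max = tile_max;
        }

        // Final normalization, mirroring the shader's inv_exp_sum writeback.
        void finalize() {
            const float inv = exp_sum != 0.0f ? 1.0f / exp_sum : 0.0f;
            for (float & o : out) {
                o *= inv;
            }
        }
    };

Because each tile only rescales the running state by exp(old_max - new_max), the result after the last tile equals a single softmax over the whole row, without ever materializing all scores at once.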