From a1cfb645307edc61a89e41557f290f441043d3c2 Mon Sep 17 00:00:00 2001
From: Zheyuan Chen
Date: Thu, 2 Apr 2026 10:40:42 -0700
Subject: [PATCH] ggml-webgpu: add vectorized flash attention (#20709)

* naive vectorized version

* add vectorized flash attention

* update vec version

* remove unused path and shader

* remove unused helper functions

* add comments

* remove pad path

* ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization

* change back to vec4

* enable multi split

* enable vec path when:
  - Q->ne[1] < 20
  - Q->ne[0] % 32 == 0
  - V->ne[0] % 4 == 0
  - K->type == f16

* update flash_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select

* enable vec path for q4 and q8

* flash-attn vec nwg=1 fast path (skip tmp/reduce staging)

* use packed f16 K loads in flash-attn vec split

* use packed f16 K loads in flash-attn vec split on host side

* tune flash-attn vec f16 VEC_NE by head dim

* cleanup

* cleanup

* keep host side clean

* cleanup host side

* change back to original host wait/submit behavior

* formatting

* reverted param-buffer pool refactor

* add helper functions

* ggml-webgpu: move flash-attn vec pipeline caching back into shader lib

* ggml-webgpu: remove duplicate functions

* ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation

* ggml-webgpu: revert unrelated change

* ggml-webgpu: revert deleted comment

* disable uniformity check

* remove unnecessary change

* Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl

* Update ggml/src/ggml-webgpu/ggml-webgpu.cpp

---------

Co-authored-by: Reese Levine
---
 .../ggml-webgpu/ggml-webgpu-shader-lib.hpp    | 230 +++++-
 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 323 +++++++-
 .../wgsl-shaders/flash_attn_vec_blk.wgsl      | 105 +++
 .../wgsl-shaders/flash_attn_vec_reduce.wgsl   |  78 ++
 .../wgsl-shaders/flash_attn_vec_split.wgsl    | 729 ++++++++++++++++++
 5 files changed, 1412 insertions(+), 53 deletions(-)
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
 create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
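For reference, the gating that the host code in this patch applies before taking the new vectorized path can be summarized by the following standalone sketch (C++; the free function and its parameter names are illustrative only, the conditions are the ones from ggml_webgpu_flash_attn below):

    #include <cstdint>

    // Illustrative restatement of the vec-path gating in ggml_webgpu_flash_attn.
    // q_rows = Q->ne[1], head_dim_qk = Q->ne[0], head_dim_v = V->ne[0].
    inline bool flash_attn_use_vec(int64_t q_rows, int64_t head_dim_qk, int64_t head_dim_v,
                                   bool kv_type_supported,  // K is f16, q4_0 or q8_0
                                   bool f16_vec4_aligned,   // only checked for f16 K/V
                                   bool k_is_f16, bool v_type_matches_k) {
        return q_rows < 20 &&            // short decode-style batches only
               head_dim_qk % 32 == 0 &&  // head dim must be a multiple of 32
               head_dim_v % 4 == 0 &&    // V rows are loaded as vec4
               kv_type_supported &&
               (!k_is_f16 || f16_vec4_aligned) &&
               v_type_matches_k;
    }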
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index a194ce84e2..1c56c68931 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
     uint32_t wg_size = 0;
 };
 
+struct ggml_webgpu_processed_shader {
+    std::string           wgsl;
+    std::string           variant;
+    std::shared_ptr<void> decisions;
+};
+
 struct ggml_webgpu_ssm_conv_shader_decisions {
     uint32_t block_size;
     uint32_t tokens_per_wg;
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
     bool has_mask;
     bool has_sinks;
     bool uses_logit_softcap;
+    bool use_vec;
 
     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
                kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
+               uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
     }
 };
 
@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        ggml_webgpu_hash_combine(seed, key.use_vec);
         return seed;
     }
 };
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
     uint32_t wg_size = 0;
 };
 
+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+    // Keep conservative defaults unless this is the f16 vec-split shape family.
+    if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
+        return 1u;
+    }
+
+    // Head-dim specializations used by the tuned vec f16 path.
+    switch (key.head_dim_qk) {
+        case 64:  return 2u;
+        case 96:  return 4u;
+        case 128: return 1u;
+        case 192: return 2u;
+        case 576: return 2u;
+        default:  return 1u;
+    }
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
+    uint32_t head_dim_v;
+    uint32_t wg_size;
+};
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.wg_size);
+        return seed;
+    }
+};
+
+inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
+                       const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
+    return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
+    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
+    uint32_t                                       max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
+    pre_wgsl::Preprocessor &                                     preprocessor,
+    const char *                                                 shader_src,
+    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "flash_attn_vec_reduce";
+
+    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+    variant += std::string("_wg") + std::to_string(context.max_wg_size);
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key {
+    uint32_t q_tile;
+    uint32_t kv_tile;
+
+    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
+        return q_tile == other.q_tile && kv_tile == other.kv_tile;
+    }
+};
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.q_tile);
+        ggml_webgpu_hash_combine(seed, key.kv_tile);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_flash_attn_blk_shader_lib_context {
+    ggml_webgpu_flash_attn_blk_pipeline_key key;
+    uint32_t                                max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
+    pre_wgsl::Preprocessor &                              preprocessor,
+    const char *                                          shader_src,
+    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "flash_attn_vec_blk";
+
+    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
+    variant += std::string("_qt") + std::to_string(context.key.q_tile);
+
+    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
+    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
+
+    uint32_t wg_size = 1;
+    while ((wg_size << 1) <= context.max_wg_size) {
+        wg_size <<= 1;
+    }
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    variant += std::string("_wg") + std::to_string(wg_size);
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
 // This is exposed because it's necessary in supports_op
 inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                   uint32_t kv_tile,
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
         repeat_pipelines;  // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
+    std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
+                       webgpu_pipeline,
+                       ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
+        flash_attn_vec_reduce_pipelines;
+    std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
+                       webgpu_pipeline,
+                       ggml_webgpu_flash_attn_blk_pipeline_key_hash>
+        flash_attn_blk_pipelines;
     std::unordered_map
@@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib {
         return repeat_pipelines[key];
     }
 
-    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const bool has_mask  = context.src3 != nullptr;
-        const bool has_sinks = context.src4 != nullptr;
-
-        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
-                         (context.src1->ne[1] % context.sg_mat_n == 0);
-
-        ggml_webgpu_flash_attn_pipeline_key key = {
-            .kv_type            = context.src1->type,
-            .head_dim_qk        = (uint32_t) context.src0->ne[0],
-            .head_dim_v         = (uint32_t) context.src2->ne[0],
-            .kv_direct          = kv_direct,
-            .has_mask           = has_mask,
-            .has_sinks          = has_sinks,
-            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
-        };
-
-        auto it = flash_attn_pipelines.find(key);
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+        auto it = flash_attn_pipelines.find(context.key);
         if (it != flash_attn_pipelines.end()) {
             return it->second;
         }
@@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib {
         std::vector<std::string> defines;
         std::string variant = "flash_attn";
 
-        switch (key.kv_type) {
+        switch (context.key.kv_type) {
             case GGML_TYPE_F32:
                 defines.push_back("KV_F32");
                 break;
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
             default:
                 GGML_ABORT("Unsupported KV type for flash attention shader");
         }
-        variant += std::string("_") + ggml_type_name(key.kv_type);
+        variant += std::string("_") + ggml_type_name(context.key.kv_type);
 
-        if (key.has_mask) {
+        if (context.key.has_mask) {
             defines.push_back("MASK");
             variant += "_mask";
         }
-        if (key.has_sinks) {
+        if (context.key.has_sinks) {
             defines.push_back("SINKS");
             variant += "_sinks";
         }
-        if (key.uses_logit_softcap) {
+        if (context.key.uses_logit_softcap) {
             defines.push_back("LOGIT_SOFTCAP");
             variant += "_lgsc";
         }
-        if (key.kv_direct) {
+        if (context.key.kv_direct) {
             defines.push_back("KV_DIRECT");
             variant += "_kvdirect";
         }
+        if (context.key.has_mask && context.key.use_vec) {
+            defines.push_back("BLK");
+            variant += "_blk";
+        }
 
-        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
-        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
 
-        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
-        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
 
         defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
         defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
         defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
 
-        uint32_t q_tile = context.sg_mat_m;
+        uint32_t q_tile  = context.sg_mat_m;
         uint32_t kv_tile =
-            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
-                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),
+            std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                      context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
 
-        if (key.kv_direct) {
+        if (context.key.use_vec) {
+            q_tile  = 1;
+            kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
+            kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
+            const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
+            defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
+        }
+        if (context.key.kv_direct) {
+            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
             while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                 kv_tile -= context.sg_mat_n;
             }
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
         defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
 
-        uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        uint32_t wg_size = 0;
+        if (context.key.use_vec) {
+            wg_size = std::max(1u, std::min(32u, context.max_subgroup_size));
+        } else {
+            wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        }
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
+        const char *    shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
+        webgpu_pipeline pipeline =
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
 
         auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
         decisions->q_tile  = q_tile;
         decisions->kv_tile = kv_tile;
         decisions->wg_size = wg_size;
+        pipeline.context   = decisions;
+        flash_attn_pipelines[context.key] = pipeline;
+        return flash_attn_pipelines[context.key];
+    }
 
-        webgpu_pipeline pipeline  = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context          = decisions;
-        flash_attn_pipelines[key] = pipeline;
-        return flash_attn_pipelines[key];
+    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
+        auto it = flash_attn_blk_pipelines.find(context.key);
+        if (it != flash_attn_blk_pipelines.end()) {
+            return it->second;
+        }
+
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+        flash_attn_blk_pipelines[context.key] = pipeline;
+        return flash_attn_blk_pipelines[context.key];
+    }
+
+    webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
+        const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+        auto it = flash_attn_vec_reduce_pipelines.find(context.key);
+        if (it != flash_attn_vec_reduce_pipelines.end()) {
+            return it->second;
+        }
+
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+        flash_attn_vec_reduce_pipelines[context.key] = pipeline;
+        return flash_attn_vec_reduce_pipelines[context.key];
     }
 
     webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
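The vec branch above fixes q_tile to 1 and sizes the KV tile as in this sketch (same clamping as get_flash_attn_pipeline; the kv_seq_pad parameter stands in for GGML_WEBGPU_KV_SEQ_PAD, whose value is backend-defined):

    #include <algorithm>
    #include <cstdint>

    // Sketch of the vec-path KV-tile sizing: clamp to [sg_mat_n, 32], round down to a
    // multiple of sg_mat_n, then for the kv_direct fast path shrink until the tile
    // divides the KV sequence padding.
    inline uint32_t vec_kv_tile(uint32_t max_kv_tile, uint32_t sg_mat_n, bool kv_direct,
                                uint32_t kv_seq_pad /* placeholder for GGML_WEBGPU_KV_SEQ_PAD */) {
        uint32_t kv_tile = std::max(sg_mat_n, std::min(32u, max_kv_tile));
        kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
        if (kv_direct) {
            while (kv_seq_pad % kv_tile != 0) {
                kv_tile -= sg_mat_n;
            }
        }
        return kv_tile;
    }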
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 1aa15b0507..e53281bfbb 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     for (size_t i = 0; i < params_bufs_list.size(); i++) {
         ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
     }
-
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
     if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
 
-#ifndef __EMSCRIPTEN__
 static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                              ggml_tensor *   Q,
                                              ggml_tensor *   K,
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0               = Q,
-        .src1               = K,
-        .src2               = V,
-        .src3               = mask,
-        .src4               = sinks,
-        .dst                = dst,
-        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+    const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
+    const bool     f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
+
+    const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
+                           (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
+                           (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+    const bool kv_vec_type_supported =
+        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+    const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+                         (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
+    const uint32_t vec_nwg_cap =
+        std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size));
+    const bool use_blk = use_vec && has_mask;
+
+    ggml_webgpu_flash_attn_pipeline_key key = {
+        .kv_type            = K->type,
+        .head_dim_qk        = (uint32_t) Q->ne[0],
+        .head_dim_v         = (uint32_t) V->ne[0],
+        .kv_direct          = kv_direct,
+        .has_mask           = static_cast<bool>(has_mask),
+        .has_sinks          = static_cast<bool>(has_sinks),
+        .uses_logit_softcap = logit_softcap != 0.0f,
+        .use_vec            = use_vec,
+    };
+
+    ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+        .key                = key,
         .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
         .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
         .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
         .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,
     };
-
     webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
 
     uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
     uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
+
+    wgpu::Buffer blk_buf         = {};
+    uint64_t     blk_size_bytes  = 0;
+    uint32_t     blk_nblk0       = 0;
+    uint32_t     blk_nblk1       = 0;
+    uint32_t     blk_batch_count = 0;
+
+    if (use_vec) {
+        uint32_t       nwg     = 1u;
+        const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
+        while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+            nwg <<= 1;
+        }
+        nwg = std::min(nwg, vec_nwg_cap);
+        GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
+        const uint64_t nrows          = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+        const bool     use_vec_reduce = nwg > 1u;
+        GGML_ASSERT(nrows <= UINT32_MAX);
+
+        uint64_t     tmp_stats_base  = 0;
+        uint64_t     tmp_size_bytes  = 0;
+        wgpu::Buffer tmp_buf         = {};
+        uint64_t     tmp_bind_offset = 0;
+        uint64_t     tmp_bind_size   = 0;
+        const size_t align_bytes     = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+        const size_t dst_offset      = ggml_webgpu_tensor_offset(dst);
+        size_t       scratch_offset  = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
+
+        if (use_vec_reduce) {
+            const uint64_t tmp_data_elems  = nrows * (uint64_t) V->ne[0] * nwg;
+            const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+            tmp_stats_base = tmp_data_elems;
+            tmp_size_bytes =
+                ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+            GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
+            tmp_buf         = ggml_webgpu_tensor_buf(dst);
+            tmp_bind_offset = scratch_offset;
+            tmp_bind_size   = tmp_size_bytes;
+            scratch_offset  = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
+        } else {
+            // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
+            tmp_buf         = ggml_webgpu_tensor_buf(dst);
+            tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
+            tmp_bind_size   = ggml_webgpu_tensor_binding_size(ctx, dst);
+        }
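The scratch carving above appends the vec-path temporaries to the tail of dst's buffer. A compact model of the layout (roundup_pow2 stands in for the backend's ROUNDUP_POW2 macro; the struct is illustrative):

    #include <cstdint>

    // Model of the scratch layout: [dst | tmp data+stats | blk states], each region
    // aligned to the device's minStorageBufferOffsetAlignment.
    inline uint64_t roundup_pow2(uint64_t v, uint64_t m) { return (v + m - 1) & ~(m - 1); }

    struct vec_scratch {
        uint64_t tmp_offset;  // nrows * head_dim_v * nwg output floats + nrows * 2 * nwg stats
        uint64_t blk_offset;  // nblk0 * nblk1 * batches block states (u32), mask path only
    };

    inline vec_scratch carve_scratch(uint64_t dst_offset, uint64_t dst_bytes,
                                     uint64_t tmp_bytes, uint64_t align) {
        vec_scratch s{};
        uint64_t off = roundup_pow2(dst_offset + dst_bytes, align);
        s.tmp_offset = off;
        off = roundup_pow2(off + tmp_bytes, align);
        s.blk_offset = off;
        return s;
    }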
+        webgpu_pipeline                   blk_pipeline;
+        std::vector<uint32_t>             blk_params;
+        std::vector<wgpu::BindGroupEntry> blk_entries;
+        if (use_blk) {
+            GGML_ASSERT(has_mask);
+
+            blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
+            blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
+            blk_buf   = ggml_webgpu_tensor_buf(dst);
+            const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+            blk_batch_count             = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+            const uint64_t blk_elems    = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+            blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+            ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
+                .key =
+                    {
+                        .q_tile  = decisions->q_tile,
+                        .kv_tile = decisions->kv_tile,
+                    },
+                .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+            };
+            blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
+
+            blk_params = {
+                (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)),  // offset_mask
+                (uint32_t) Q->ne[1],  // seq_len_q
+                (uint32_t) K->ne[1],  // seq_len_kv
+                stride_mask3,         // stride_mask3
+                blk_nblk0,            // nblk0
+                blk_nblk1,            // nblk1
+            };
+            blk_entries = {
+                { .binding = 0,
+                  .buffer  = ggml_webgpu_tensor_buf(mask),
+                  .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
+                  .size    = ggml_webgpu_tensor_binding_size(ctx, mask) },
+                { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
+            };
+            scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
+        }
+
+        std::vector<uint32_t> split_params = params;
+        if (use_blk) {
+            split_params.push_back(0u);         // blk_base
+            split_params.push_back(blk_nblk0);  // blk_nblk0
+            split_params.push_back(blk_nblk1);  // blk_nblk1
+        }
+        split_params.push_back(0u);                         // tmp_data_base
+        split_params.push_back((uint32_t) tmp_stats_base);  // tmp_stats_base
+        split_params.push_back(nwg);                        // nwg
+
+        std::vector<wgpu::BindGroupEntry> split_entries = {
+            { .binding = 0,
+              .buffer  = ggml_webgpu_tensor_buf(Q),
+              .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
+              .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
+            { .binding = 1,
+              .buffer  = ggml_webgpu_tensor_buf(K),
+              .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
+              .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
+            { .binding = 2,
+              .buffer  = ggml_webgpu_tensor_buf(V),
+              .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
+              .size    = ggml_webgpu_tensor_binding_size(ctx, V) },
+        };
+        uint32_t split_binding_index = 3;
+        if (has_mask) {
+            split_entries.push_back({ .binding = split_binding_index++,
+                                      .buffer  = ggml_webgpu_tensor_buf(mask),
+                                      .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
+                                      .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
+        }
+        if (has_sinks) {
+            split_entries.push_back({ .binding = split_binding_index++,
+                                      .buffer  = ggml_webgpu_tensor_buf(sinks),
+                                      .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
+                                      .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+        }
+        if (use_blk) {
+            split_entries.push_back(
+                { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
+        }
+        split_entries.push_back(
+            { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
+        split_entries.push_back({ .binding = split_binding_index++,
+                                  .buffer  = ggml_webgpu_tensor_buf(dst),
+                                  .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                                  .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+        webgpu_pipeline                   reduce_pipeline;
+        std::vector<uint32_t>             reduce_params;
+        std::vector<wgpu::BindGroupEntry> reduce_entries;
+        if (use_vec_reduce) {
+            const uint32_t reduce_wg_size = std::max(
+                32u,
+                std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
+            ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
+                .key =
+                    {
+                        .head_dim_v = (uint32_t) V->ne[0],
+                        .wg_size    = reduce_wg_size,
+                    },
+                .max_wg_size = reduce_wg_size,
+            };
+            reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
+
+            reduce_params = {
+                (uint32_t) nrows,     // nrows
+                (uint32_t) Q->ne[1],  // seq_len_q
+                (uint32_t) Q->ne[2],  // n_heads
+                (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),  // offset_dst
+                nwg,                        // nwg
+                0u,                         // tmp_data_base
+                (uint32_t) tmp_stats_base,  // tmp_stats_base
+            };
+
+            reduce_entries = {
+                { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
+                { .binding = 1,
+                  .buffer  = ggml_webgpu_tensor_buf(dst),
+                  .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                  .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
+            };
+        }
+
+        const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
+        GGML_ASSERT(split_wg_total <= UINT32_MAX);
+        std::vector<webgpu_pipeline>                   pipelines;
+        std::vector<std::vector<uint32_t>>             params_list;
+        std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+        std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;
+
+        if (use_blk) {
+            pipelines.push_back(blk_pipeline);
+            params_list.push_back(std::move(blk_params));
+            entries_list.push_back(std::move(blk_entries));
+            workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
+        }
+        pipelines.push_back(pipeline);
+        params_list.push_back(std::move(split_params));
+        entries_list.push_back(std::move(split_entries));
+        workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
+        if (use_vec_reduce) {
+            pipelines.push_back(reduce_pipeline);
+            params_list.push_back(std::move(reduce_params));
+            entries_list.push_back(std::move(reduce_entries));
+            workgroups_list.push_back({ (uint32_t) nrows, 1u });
+        }
+
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+                                               entries_list, workgroups_list);
+    }
+
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
-#endif
 
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
     std::vector<webgpu_command> subs;
     uint32_t num_batched_kernels = 0;
     bool contains_set_rows = false;
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
             contains_set_rows = true;
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                 }
             }
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                const ggml_tensor * Q     = tensor->src[0];
+                const ggml_tensor * K     = tensor->src[1];
+                const ggml_tensor * V     = tensor->src[2];
+                const ggml_tensor * mask  = tensor->src[3];
+                const ggml_tensor * sinks = tensor->src[4];
+                if (Q && K && V) {
+                    GGML_UNUSED(sinks);
+                    const bool kv_direct = (K->type == GGML_TYPE_F16) &&
+                                           (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
+                                           (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+                    const bool kv_vec_type_supported =
+                        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+                    const bool use_vec =
+                        (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+                        (V->type == K->type);
+                    if (use_vec) {
+                        const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
+                        const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
+                        const size_t   limit_bytes =
+                            ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                        const size_t q_tile = sg_mat_m;
+                        const size_t base_q_bytes =
+                            (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+                            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+                        size_t bytes_per_kv = 0;
+                        if (!kv_direct) {
+                            bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
+                        }
+                        if (mask != nullptr) {
+                            bytes_per_kv += q_tile;
+                        }
+                        bytes_per_kv += q_tile;
+                        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+                        uint32_t kv_tile =
+                            ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
+                        kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
+                        kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
+                        if (kv_direct) {
+                            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+                            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+                                kv_tile -= sg_mat_n;
+                            }
+                        }
+
+                        const uint32_t vec_nwg_cap = std::max(
+                            1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
+                        uint32_t       nwg     = 1u;
+                        const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
+                        while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+                            nwg <<= 1;
+                        }
+                        nwg = std::min(nwg, vec_nwg_cap);
+
+                        const size_t   align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+                        const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+                        if (nwg > 1u) {
+                            const uint64_t tmp_data_elems  = nrows * (uint64_t) V->ne[0] * nwg;
+                            const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+                            const size_t   tmp_size_bytes  = ROUNDUP_POW2(
+                                (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+                            res += tmp_size_bytes + align;
+                        }
+                        if (mask != nullptr) {
+                            const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
+                            const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
+                            const uint32_t stride_mask3 =
+                                (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+                            const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+                            const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+                            const size_t   blk_size_bytes =
+                                ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+                            res += blk_size_bytes + align;
+                        }
+                        res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
+                    }
+                }
+            }
+            break;
         default:
             break;
     }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
new file mode 100644
index 0000000000..82d072be73
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
@@ -0,0 +1,105 @@
+diagnostic(off, subgroup_uniformity);
+enable f16;
+
+#define Q_TILE 1
+#define KV_TILE 32
+#define WG_SIZE 32
+
+struct Params {
+    offset_mask: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+    stride_mask3: u32,
+    // Number of KV blocks and Q blocks per batch.
+    // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
+    nblk0: u32,
+    nblk1: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const MASK_MIN: f32 = -65504.0;
+const MASK_MAX: f32 = 65504.0;
+var<workgroup> wg_min: array<f32, WG_SIZE>;
+var<workgroup> wg_max: array<f32, WG_SIZE>;
+var<workgroup> wg_any: array<u32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>) {
+    // Dispatch mapping:
+    //   - x indexes KV blocks
+    //   - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
+    let kv_blk = wg_id.x;
+    let y = wg_id.y;
+    let q_blk = y % params.nblk1;
+    let batch_idx = y / params.nblk1;
+    if (kv_blk >= params.nblk0) {
+        return;
+    }
+
+    let q_start = q_blk * Q_TILE;
+    let k_start = kv_blk * KV_TILE;
+
+    let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+    let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
+
+    // We keep min/max to classify:
+    //   - fully masked (max <= MASK_MIN)
+    //   - all-zero mask (min == 0 && max == 0)
+    //   - mixed/general mask
+    var local_min = MASK_MAX;
+    var local_max = -MASK_MAX;
+    var local_any = 0u;
+
+    for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
+        let q_row = q_start + q_rel;
+        if (q_row >= params.seq_len_q) {
+            continue;
+        }
+        let row_base = mask_batch_base + q_row * params.seq_len_kv;
+        for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
+            let k_col = k_start + k_rel;
+            if (k_col >= params.seq_len_kv) {
+                continue;
+            }
+            let mv = f32(mask[row_base + k_col]);
+            local_min = min(local_min, mv);
+            local_max = max(local_max, mv);
+            local_any = 1u;
+        }
+    }
+
+    wg_min[local_id.x] = local_min;
+    wg_max[local_id.x] = local_max;
+    wg_any[local_id.x] = local_any;
+    workgroupBarrier();
+
+    // Thread 0 writes one state per block.
+    if (local_id.x == 0u) {
+        var mmin = wg_min[0];
+        var mmax = wg_max[0];
+        var many = wg_any[0];
+        for (var i = 1u; i < WG_SIZE; i += 1u) {
+            mmin = min(mmin, wg_min[i]);
+            mmax = max(mmax, wg_max[i]);
+            many = max(many, wg_any[i]);
+        }
+
+        var state = 0u;
+        if (many != 0u) {
+            if (mmax <= MASK_MIN) {
+                state = 0u;
+            } else if (mmin == 0.0 && mmax == 0.0) {
+                state = 2u;
+            } else {
+                state = 1u;
+            }
+        }
+
+        let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
+        blk[blk_idx] = state;
+    }
+}
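The shader above reduces each (q block, kv block) mask tile to one of three states that the split kernel then branches on. A host-side restatement of the classification (C++; -65504 matches the shader's MASK_MIN, the most negative finite f16):

    #include <cstdint>

    // Restatement of the per-tile mask classification in flash_attn_vec_blk.wgsl.
    enum blk_state : uint32_t {
        BLK_SKIP      = 0,  // fully masked (or no valid element): the split kernel skips the tile
        BLK_GENERAL   = 1,  // mixed values: the mask is applied per element
        BLK_ZERO_MASK = 2,  // all zeros: the mask load and add are skipped
    };

    inline uint32_t classify_mask_tile(const float * vals, uint32_t n) {
        const float MASK_MIN = -65504.0f;
        if (n == 0) {
            return BLK_SKIP;  // mirrors many == 0 in the shader
        }
        float mn = vals[0], mx = vals[0];
        for (uint32_t i = 1; i < n; ++i) {
            mn = vals[i] < mn ? vals[i] : mn;
            mx = vals[i] > mx ? vals[i] : mx;
        }
        if (mx <= MASK_MIN)             return BLK_SKIP;
        if (mn == 0.0f && mx == 0.0f)   return BLK_ZERO_MASK;
        return BLK_GENERAL;
    }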
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
new file mode 100644
index 0000000000..9a0de82a56
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
@@ -0,0 +1,78 @@
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+
+// Default values
+#define HEAD_DIM_V 64
+#define WG_SIZE 128
+
+struct Params {
+    nrows: u32,
+    seq_len_q: u32,
+    n_heads: u32,
+    offset_dst: u32,
+    nwg: u32,
+    tmp_data_base: u32,
+    tmp_stats_base: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    let rid = wg_id.x;
+    if (rid >= params.nrows) {
+        return;
+    }
+
+    let rows_per_batch = params.n_heads * params.seq_len_q;
+    let batch_idx = rid / rows_per_batch;
+    let rem = rid % rows_per_batch;
+    let head_idx = rem / params.seq_len_q;
+    let q_row = rem % params.seq_len_q;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+    let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+    let thread = sg_inv_id;
+    if (params.nwg > subgroup_size) {
+        return;
+    }
+
+    let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
+    let active_thread = thread < params.nwg;
+    let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
+    let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
+    let m = subgroupMax(mi);
+    let ms = select(0.0, exp(mi - m), active_thread);
+    let s = subgroupAdd(si * ms);
+    let inv_s = select(0.0, 1.0 / s, s != 0.0);
+
+    let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
+    for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
+        var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        if (active_thread) {
+            let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
+            weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
+        }
+
+        let sum_x = subgroupAdd(weighted.x);
+        let sum_y = subgroupAdd(weighted.y);
+        let sum_z = subgroupAdd(weighted.z);
+        let sum_w = subgroupAdd(weighted.w);
+
+        if (thread == 0u) {
+            let dst_vec_index = (row_base + elem_base) >> 2u;
+            dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
+        }
+    }
+}
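For intuition, the reduce shader above merges the per-split partials with the usual log-sum-exp rescaling. A scalar C++ model for one output row (each split i leaves an unnormalized output o_i plus its exponent sum s_i and running max m_i in tmp):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar model of flash_attn_vec_reduce.wgsl for one output row.
    inline std::vector<float> merge_splits(const std::vector<std::vector<float>> & o,  // [nwg][head_dim_v]
                                           const std::vector<float> & s,               // per-split exp sums
                                           const std::vector<float> & m) {             // per-split maxima
        const size_t nwg = o.size(), d = o[0].size();
        float mmax = m[0];
        for (float mi : m) {
            mmax = std::max(mmax, mi);
        }
        float stot = 0.0f;
        std::vector<float> out(d, 0.0f);
        for (size_t i = 0; i < nwg; ++i) {
            const float w = std::exp(m[i] - mmax);  // rescale split i to the global max
            stot += s[i] * w;
            for (size_t j = 0; j < d; ++j) {
                out[j] += o[i][j] * w;
            }
        }
        const float inv = stot != 0.0f ? 1.0f / stot : 0.0f;
        for (float & v : out) {
            v *= inv;
        }
        return out;
    }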
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
new file mode 100644
index 0000000000..a52575871a
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
@@ -0,0 +1,729 @@
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#ifdef KV_F32
+#define KV_TYPE f32
+#else
+#define KV_TYPE f16
+#endif
+
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+
+#define SG_MAT_M 8
+#define SG_MAT_N 8
+#define SG_MAT_K 8
+
+#define Q_TILE SG_MAT_M
+#define KV_TILE 16
+#define WG_SIZE 64
+#ifndef VEC_NE
+#define VEC_NE 4u
+#endif
+
+#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+
+#define BLOCK_SIZE 32
+#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#if defined(KV_Q4_0)
+#define NQ 16
+#define F16_PER_BLOCK 9
+#define WEIGHTS_PER_F16 4
+#elif defined(KV_Q8_0)
+#define NQ 8
+#define F16_PER_BLOCK 17
+#define WEIGHTS_PER_F16 2
+#endif
+#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
+
+fn get_byte(value: u32, index: u32) -> u32 {
+    return (value >> (index * 8)) & 0xFF;
+}
+
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+struct Params {
+    offset_q: u32,
+    offset_k: u32,
+    offset_v: u32,
+    offset_mask: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // shapes of Q/K/V
+    n_heads: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+
+    // strides (in elements)
+    stride_q1: u32,
+    stride_q2: u32,
+    stride_q3: u32,
+    stride_k1: u32,
+    stride_k2: u32,
+    stride_k3: u32,
+    stride_v1: u32,
+    stride_v2: u32,
+    stride_v3: u32,
+    stride_mask3: u32,
+
+    // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
+    q_per_kv: u32,
+
+    // softmax params
+    scale: f32,
+    max_bias: f32,
+    logit_softcap: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+
+#ifdef BLK
+    blk_base: u32,
+    blk_nblk0: u32,
+    blk_nblk1: u32,
+#endif
+
+    tmp_data_base: u32,
+    tmp_stats_base: u32,
+    nwg: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(1) var<storage, read_write> K: array<f16>;
+#else
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
+#endif
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(2) var<storage, read_write> V: array<f16>;
+#else
+@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
+#endif
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#ifdef BLK
+#define BLK_BINDING 5
+#define TMP_BINDING 6
+#define DST_BINDING 7
+#define PARAMS_BINDING 8
+#else
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#endif
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#ifdef BLK
+#define BLK_BINDING 4
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#else
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#else
+#define TMP_BINDING 3
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+
+#ifdef BLK
+@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
+#endif
+@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+// Just a very small float value.
+const FLOAT_MIN: f32 = -1.0e9;
+
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+
+#ifndef KV_DIRECT
+const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
+// we can reuse the same shmem for K and V since we only need one at a time
+var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
+#endif
+
+var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
+
+#ifdef MASK
+// storage for mask values
+var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+#endif
+
+// note that we reuse the same storage for both since we only need one at a time
+var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+
+// Storage for row max and exp sum during online softmax
+var<workgroup> row_max_shmem: array<f32, Q_TILE>;
+var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
+var<workgroup> blk_state_wg: u32;
+
+fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
+    var v = select(FLOAT_MIN,
+                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+                   kv_idx < KV_TILE);
+#ifdef LOGIT_SOFTCAP
+    v = params.logit_softcap * tanh(v);
+#endif
+#ifdef MASK
+    if (apply_mask) {
+        var mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+        v += select(mask_val, slope * mask_val, has_bias);
+    }
+#endif
+    return v;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+
+    // initialize row max for online softmax
+    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
+        row_max_shmem[i] = FLOAT_MIN;
+        exp_sum_shmem[i] = 0.0;
+    }
+
+    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+        o_shmem[i] = 0.0;
+    }
+
+    // workgroups per head/batch
+    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_batch = wg_per_head * params.n_heads;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+
+    let iwg = wg_id.x % params.nwg;
+    let base_wg_id = wg_id.x / params.nwg;
+
+    // batch index
+    let batch_idx = base_wg_id / wg_per_batch;
+    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+    let wg_in_batch = base_wg_id % wg_per_batch;
+
+    // head index
+    let head_idx = wg_in_batch / wg_per_head;
+    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+    let k_head_idx = head_idx / params.q_per_kv;
+    let v_head_idx = k_head_idx;
+    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
+
+    // starting Q row for this workgroup
+    let wg_in_head = wg_in_batch % wg_per_head;
+    let q_row_start = wg_in_head * Q_TILE;
+
+#ifdef MASK
+    // mask offset
+    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    let head = f32(head_idx);
+    let has_bias = params.max_bias > 0.0;
+    let slope = select(1.0,
+                       select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0),
+                              pow(params.m0, head + 1.0),
+                              head < params.n_head_log2),
+                       has_bias);
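The slope select-chain above is the standard per-head ALiBi slope; written out as a scalar helper for readability (C++; same formula, the function itself is illustrative):

    #include <cmath>

    // Per-head ALiBi slope as computed above: with max_bias <= 0 the slope is 1;
    // otherwise heads below n_head_log2 use m0^(h+1), the rest m1^(2*(h - n_head_log2) + 1).
    inline float alibi_slope(float head, float n_head_log2, float m0, float m1, float max_bias) {
        if (max_bias <= 0.0f) {
            return 1.0f;
        }
        return head < n_head_log2 ? std::pow(m0, head + 1.0f)
                                  : std::pow(m1, 2.0f * (head - n_head_log2) + 1.0f);
    }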
+    // load q tile into shared memory
+    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let q_row = elem_idx / HEAD_DIM_QK;
+        let q_col = elem_idx % HEAD_DIM_QK;
+        let head_q_row = q_row_start + q_row;
+        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+        q_shmem[elem_idx] = f16(select(
+            0.0,
+            Q[global_q_row_offset + q_col],
+            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+    }
+
+    for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
+#ifdef BLK
+        let q_blk = q_row_start / Q_TILE;
+        let kv_blk = kv_tile / KV_TILE;
+        let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+        let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
+        let blk_state_local = blk[blk_idx];
+#else
+        let blk_state_local = 1u;
+#endif
+        if (local_id.x == 0u) {
+            blk_state_wg = blk_state_local;
+        }
+        workgroupBarrier();
+        let blk_state = blk_state_wg;
+        let skip_tile = blk_state == 0u;
+        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+            inter_shmem[elem_idx] = f16(0.0);
+        }
+
+        // load k tile into shared memory
+#if defined(KV_Q4_0)
+        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+            let blck_idx = elem_idx / BLOCK_SIZE;
+            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+            let k_row = blck_idx / BLOCKS_K;
+            let global_k_row = kv_tile + k_row;
+            let block_k = blck_idx % BLOCKS_K;
+            let row_offset = k_row * HEAD_DIM_QK;
+
+            if (global_k_row < params.seq_len_kv) {
+                let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+                let base_idx = global_block_idx * F16_PER_BLOCK;
+                let d = K[base_idx];
+                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                    let q_0 = K[base_idx + 1u + block_offset + j];
+                    let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                    let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                    for (var k = 0u; k < 4u; k++) {
+                        let q_byte = get_byte(q_packed, k);
+                        let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                        let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                        kv_shmem[row_offset + idx] = q_lo;
+                        kv_shmem[row_offset + idx + 16u] = q_hi;
+                    }
+                }
+            }
+        }
+#elif defined(KV_Q8_0)
+        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+            let blck_idx = elem_idx / BLOCK_SIZE;
+            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+            let k_row = blck_idx / BLOCKS_K;
+            let global_k_row = kv_tile + k_row;
+            let block_k = blck_idx % BLOCKS_K;
+            let row_offset = k_row * HEAD_DIM_QK;
+
+            if (global_k_row < params.seq_len_kv) {
+                let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+                let base_idx = global_block_idx * F16_PER_BLOCK;
+                let d = K[base_idx];
+                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                    let q_0 = K[base_idx + 1u + block_offset + j];
+                    let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                    let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                    for (var k = 0u; k < 4u; k++) {
+                        let q_byte = get_byte_i32(q_packed, k);
+                        let q_val = f16(q_byte) * d;
+                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                        kv_shmem[row_offset + idx] = q_val;
+                    }
+                }
+            }
+        }
+#elif defined(KV_DIRECT)
+        // Direct global loads for KV
+#else
+        for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
+            let k_row = elem_idx / HEAD_DIM_QK;
+            let k_col = elem_idx % HEAD_DIM_QK;
+            let global_k_row = kv_tile + k_row;
+            let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
+            let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
+            let vec_idx = (global_k_row_offset + k_col) >> 2u;
+            let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
+            kv_shmem[elem_idx + 0u] = f16(k4.x);
+            kv_shmem[elem_idx + 1u] = f16(k4.y);
+            kv_shmem[elem_idx + 2u] = f16(k4.z);
+            kv_shmem[elem_idx + 3u] = f16(k4.w);
+        }
+#endif
+
+        workgroupBarrier();
+
+        // accumulate q block * k block into registers across the entire KV tile
+        if (!skip_tile) {
+            let num_of_threads = subgroup_size / VEC_NE;
+            let tx = sg_inv_id % num_of_threads;
+            let ty = sg_inv_id / num_of_threads;
+            for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+                let global_q_row = q_row_start + q_tile_row;
+                if (global_q_row >= params.seq_len_q) {
+                    continue;
+                }
+                let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
+
+                for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
+                    let kv_idx = kv_base + ty;
+                    var partial_sum: f32 = 0.0;
+                    let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
+                    if (kv_valid) {
+                        for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
+                            let q_off = local_q_row_offset + i * 4u;
+
+                            let qv = vec4<f32>(
+                                f32(q_shmem[q_off + 0u]),
+                                f32(q_shmem[q_off + 1u]),
+                                f32(q_shmem[q_off + 2u]),
+                                f32(q_shmem[q_off + 3u]));
+#ifdef KV_DIRECT
+                            let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
+                            let kv = vec4<f32>(K[idx >> 2u]);
+#else
+                            let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
+                            let kv = vec4<f32>(
+                                f32(kv_shmem[idx + 0u]),
+                                f32(kv_shmem[idx + 1u]),
+                                f32(kv_shmem[idx + 2u]),
+                                f32(kv_shmem[idx + 3u]));
+#endif
+                            partial_sum += dot(qv, kv);
+                        }
+                    }
+                    var sum = partial_sum;
+                    // Reduce over tx threads (NL) for this ty stripe.
+                    var tx_delta = num_of_threads >> 1u;
+                    loop {
+                        if (tx_delta == 0u) {
+                            break;
+                        }
+                        let sh = subgroupShuffleDown(sum, tx_delta);
+                        if (tx < tx_delta) {
+                            sum += sh;
+                        }
+                        tx_delta >>= 1u;
+                    }
+
+                    let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
+                    if (tx == 0u && kv_valid) {
+                        let dst_idx = q_tile_row * KV_TILE + kv_idx;
+                        inter_shmem[dst_idx] = f16(sum_bcast);
+                    }
+                }
+            }
+        }
+
+#ifdef MASK
+        let apply_mask = !skip_tile && (blk_state != 2u);
+        if (apply_mask) {
+            // load mask tile into shared memory for this KV block
+            for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+                let mask_row = elem_idx / KV_TILE;
+                let mask_col = elem_idx % KV_TILE;
+                let global_q_row = q_row_start + mask_row;
+                let global_k_col = kv_tile + mask_col;
+                let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+                let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+                mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+            }
+        }
+#else
+        let apply_mask = false;
+#endif
+
+        workgroupBarrier();
+
+        // online softmax
+        if (!skip_tile) {
+            for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+                let global_q_row = q_row_start + q_tile_row;
+                if (global_q_row >= params.seq_len_q) {
+                    break;
+                }
+
+                var prev_max = row_max_shmem[q_tile_row];
+                var final_max = prev_max;
+                // pass 1: compute final max across the full KV tile in chunks
+                for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+                    let kv_idx = kv_offset + sg_inv_id;
+                    let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
+                    let softmax_term = select(FLOAT_MIN,
+                                              calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
+                                              kv_valid);
+                    final_max = subgroupMax(max(final_max, softmax_term));
+                }
+
+                var total_exp_term: f32 = 0.0;
+                // pass 2: compute exp sum and write P using final_max
+                for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+                    let kv_idx = kv_offset + sg_inv_id;
+                    let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
+                    let cur_p = select(0.0,
+                                       exp(softmax_term - final_max),
+                                       kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
+                    total_exp_term += subgroupAdd(cur_p);
+                    if (kv_idx < KV_TILE) {
+                        inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
+                    }
+                }
+
+                let cur_exp = exp(prev_max - final_max);
+
+                if (sg_inv_id == 0) {
+                    row_max_shmem[q_tile_row] = final_max;
+                    exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
+                }
+
+                for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                    let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+                    o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
+                }
+            }
+        }
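The two-pass loop above is a standard online-softmax update. A scalar C++ model of what happens to one Q row per KV tile (the P*V accumulation itself follows separately; the struct and function are illustrative):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Scalar model of the per-tile online-softmax update in the split kernel.
    struct row_state {
        float              row_max = -1.0e9f;  // FLOAT_MIN in the shader
        float              exp_sum = 0.0f;
        std::vector<float> o;                  // running output accumulator, length head_dim_v
    };

    inline void online_softmax_tile(row_state & st, const std::vector<float> & logits) {
        float tile_max = st.row_max;
        for (float l : logits) {
            tile_max = std::max(tile_max, l);       // pass 1: new running max
        }
        const float rescale = std::exp(st.row_max - tile_max);
        float tile_sum = 0.0f;
        for (float l : logits) {
            tile_sum += std::exp(l - tile_max);     // pass 2: P values, also stored as weights
        }
        st.exp_sum = st.exp_sum * rescale + tile_sum;
        for (float & v : st.o) {
            v *= rescale;                           // rescale previously accumulated P*V
        }
        st.row_max = tile_max;
    }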
+        // load v tile into shared memory
+#if defined(KV_Q4_0)
+        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+            let blck_idx = elem_idx / BLOCK_SIZE;
+            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+            let v_row = blck_idx / BLOCKS_V;
+            let global_v_row = kv_tile + v_row;
+            let block_k = blck_idx % BLOCKS_V;
+            let row_offset = v_row * HEAD_DIM_V;
+
+            if (global_v_row < params.seq_len_kv) {
+                let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+                let base_idx = global_block_idx * F16_PER_BLOCK;
+                let d = V[base_idx];
+                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                    let q_0 = V[base_idx + 1u + block_offset + j];
+                    let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                    let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                    for (var k = 0u; k < 4u; k++) {
+                        let q_byte = get_byte(q_packed, k);
+                        let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                        let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                        kv_shmem[row_offset + idx] = q_lo;
+                        kv_shmem[row_offset + idx + 16u] = q_hi;
+                    }
+                }
+            }
+        }
+#elif defined(KV_Q8_0)
+        for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+            let blck_idx = elem_idx / BLOCK_SIZE;
+            let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+            let v_row = blck_idx / BLOCKS_V;
+            let global_v_row = kv_tile + v_row;
+            let block_k = blck_idx % BLOCKS_V;
+            let row_offset = v_row * HEAD_DIM_V;
+
+            if (global_v_row < params.seq_len_kv) {
+                let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+                let base_idx = global_block_idx * F16_PER_BLOCK;
+                let d = V[base_idx];
+                for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                    let q_0 = V[base_idx + 1u + block_offset + j];
+                    let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                    let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                    for (var k = 0u; k < 4u; k++) {
+                        let q_byte = get_byte_i32(q_packed, k);
+                        let q_val = f16(q_byte) * d;
+                        let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                        kv_shmem[row_offset + idx] = q_val;
+                    }
+                }
+            }
+        }
+#elif defined(KV_DIRECT)
+        // Direct global loads for KV
+#else
+        for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
+            let v_row = elem_idx / HEAD_DIM_V;
+            let v_col = elem_idx % HEAD_DIM_V;
+            let global_v_row = kv_tile + v_row;
+            let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
+            let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
+            let vec_idx = (global_v_row_offset + v_col) >> 2u;
+            let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
+            kv_shmem[elem_idx + 0u] = f16(v4.x);
+            kv_shmem[elem_idx + 1u] = f16(v4.y);
+            kv_shmem[elem_idx + 2u] = f16(v4.z);
+            kv_shmem[elem_idx + 3u] = f16(v4.w);
+        }
+#endif
+
+        workgroupBarrier();
+
+        if (!skip_tile) {
+            // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
+            // we want to compute O += P * V across the full KV tile
+            let ne_threads : u32 = VEC_NE;
+            let nl_threads = max(1u, subgroup_size / ne_threads);
+            let tx_pv = sg_inv_id % nl_threads;
+            let ty_pv = sg_inv_id / nl_threads;
+            for (var q_tile_row = subgroup_id;
+                 q_tile_row < Q_TILE;
+                 q_tile_row += num_subgroups) {
+                for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
+                    var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+                    for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
+                        let kv_idx = cc * ne_threads + ty_pv;
+                        let v_row = kv_tile + kv_idx;
+                        if (v_row >= params.seq_len_kv) {
+                            continue;
+                        }
+
+                        let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
+#ifdef KV_DIRECT
+                        let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
+                        let v4 = vec4<f32>(V[v_idx >> 2u]);
+#else
+                        let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
+                        let v4 = vec4<f32>(
+                            f32(kv_shmem[v_idx + 0u]),
+                            f32(kv_shmem[v_idx + 1u]),
+                            f32(kv_shmem[v_idx + 2u]),
+                            f32(kv_shmem[v_idx + 3u]));
+#endif
+                        lo += p * v4;
+                    }
+
+                    var lo_x = lo.x;
+                    var lo_y = lo.y;
+                    var lo_z = lo.z;
+                    var lo_w = lo.w;
+                    // Reduce over ty threads (NE) for this tx thread.
+                    var ty_delta = ne_threads >> 1u;
+                    loop {
+                        if (ty_delta == 0u) {
+                            break;
+                        }
+                        let thread_delta = ty_delta * nl_threads;
+                        let shx = subgroupShuffleDown(lo_x, thread_delta);
+                        let shy = subgroupShuffleDown(lo_y, thread_delta);
+                        let shz = subgroupShuffleDown(lo_z, thread_delta);
+                        let shw = subgroupShuffleDown(lo_w, thread_delta);
+                        if (ty_pv < ty_delta) {
+                            lo_x += shx;
+                            lo_y += shy;
+                            lo_z += shz;
+                            lo_w += shw;
+                        }
+                        ty_delta >>= 1u;
+                    }
+
+                    if (ty_pv == 0u) {
+                        let elem_base = vec_col * 4u;
+                        let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
+                        o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
+                        o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
+                        o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
+                        o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
+                    }
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+#ifdef SINKS
+    // Sinks are global terms and must be applied exactly once across split workgroups.
+    if (iwg == 0u) {
+        for (var q_tile_row = subgroup_id;
+             q_tile_row < Q_TILE;
+             q_tile_row += num_subgroups) {
+            let global_q_row = q_row_start + q_tile_row;
+            if (global_q_row >= params.seq_len_q) {
+                break;
+            }
+
+            var prev_max = row_max_shmem[q_tile_row];
+
+            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
+            let new_max = subgroupMax(max(prev_max, sink_val));
+            let max_exp = exp(prev_max - new_max);
+            let sink_exp = exp(sink_val - new_max);
+
+            let sink_exp_sum = subgroupAdd(sink_exp);
+
+            if (sg_inv_id == 0) {
+                row_max_shmem[q_tile_row] = new_max;
+                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
+            }
+
+            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+                o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
+            }
+        }
+        workgroupBarrier();
+    }
+#endif
+
+    let rows_per_batch = params.n_heads * params.seq_len_q;
+    for (var q_tile_row = subgroup_id;
+         q_tile_row < Q_TILE;
+         q_tile_row += num_subgroups) {
+
+        let global_q_row = q_row_start + q_tile_row;
+        if (global_q_row >= params.seq_len_q) { break; }
+
+        if (params.nwg == 1u) {
+            let exp_sum = exp_sum_shmem[q_tile_row];
+            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+            let row_base: u32 =
+                params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+            for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
+                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+                let v = vec4<f32>(
+                    f32(o_shmem[i0]) * scale,
+                    f32(o_shmem[i1]) * scale,
+                    f32(o_shmem[i2]) * scale,
+                    f32(o_shmem[i3]) * scale
+                );
+
+                let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+                dst[dst_vec_index] = v;
+            }
+        } else {
+            let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
+            let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
+            let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
+
+            for (var elem_base = sg_inv_id * 4u;
+                 elem_base < HEAD_DIM_V;
+                 elem_base += subgroup_size * 4u) {
+
+                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+                let tbase = tmp_row_data_base + elem_base;
+                tmp[tbase + 0u] = f32(o_shmem[i0]);
+                tmp[tbase + 1u] = f32(o_shmem[i1]);
+                tmp[tbase + 2u] = f32(o_shmem[i2]);
+                tmp[tbase + 3u] = f32(o_shmem[i3]);
+            }
+
+            if (sg_inv_id == 0u) {
+                tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
+                tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
+            }
+        }
+    }
+}