From 15bff84bf56651d6f991f166a2bf0f362996f7f9 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 8 Jan 2026 08:23:39 -0800 Subject: [PATCH] ggml webgpu: initial flashattention implementation (#18610) * FlashAttention (#13) * Add inplace softmax * Move rms_norm to split row approach * Update debug for supports_op * clean up debug statements * neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though * neg passes backend test * unary operators pass ggml tests * rms_norm double declaration bug atoned * abides by editor-config * removed vestigial files * fixed autoconfig * All operators (inlcluding xielu) working * removed unnecesarry checking if node->src[1] exists for unary operators * responded and dealt with PR comments * implemented REPL_Template support and removed bug in unary operators kernel * formatted embed wgsl and ggml-webgpu.cpp * Faster tensors (#8) Add fast matrix and matrix/vector multiplication. * Use map for shader replacements instead of pair of strings * Wasm (#9) * webgpu : fix build on emscripten * more debugging stuff * test-backend-ops: force single thread on wasm * fix single-thread case for init_tensor_uniform * use jspi * add pthread * test: remember to set n_thread for cpu backend * Add buffer label and enable dawn-specific toggles to turn off some checks * Intermediate state * Fast working f16/f32 vec4 * Working float fast mul mat * Clean up naming of mul_mat to match logical model, start work on q mul_mat * Setup for subgroup matrix mat mul * Basic working subgroup matrix * Working subgroup matrix tiling * Handle weirder sg matrix sizes (but still % sg matrix size) * Working start to gemv * working f16 accumulation with shared memory staging * Print out available subgroup matrix configurations * Vectorize dst stores for sg matrix shader * Gemv working scalar * Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Comment on dawn toggles * Working subgroup matrix code for (semi)generic sizes * Remove some comments * Cleanup code * Update dawn version and move to portable subgroup size * Try to fix new dawn release * Update subgroup size comment * Only check for subgroup matrix configs if they are supported * Add toggles for subgroup matrix/f16 support on nvidia+vulkan * Make row/col naming consistent * Refactor shared memory loading * Move sg matrix stores to correct file * Working q4_0 * Formatting * Work with emscripten builds * Fix test-backend-ops emscripten for f16/quantized types * Use emscripten memory64 to support get_memory * Add build flags and try ci --------- Co-authored-by: Xuan Son Nguyen * Remove extra whitespace * Move wasm single-thread logic out of test-backend-ops for cpu backend * Disable multiple threads for emscripten single-thread builds in ggml_graph_plan * Refactored pipelines and workgroup calculations (#10) * refactored pipelines * refactored workgroup calculation * removed commented out block of prior maps * Clean up ceiling division pattern --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on flash attention * Shader structure set up (many bugs still) * debugging * Working first test * Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32 * Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling * Start work on integrating pre-wgsl * Separate structs/initial shader compilation library into separate files * Work on compilation choices for flashattention * Work on subgroup matrix/tile size portability * subgroup size agnostic online softmax * Cleanups, quantization types * more cleanup * fix wasm build * Refactor flashattention to increase parallelism, use direct loads for KV in somce cases * Checkpoint * formatting * Update to account for default kv cache padding * formatting shader * Add workflow for ggml-ci webgpu * Try passing absolute path to dawn in ggml-ci * Avoid error on device destruction, add todos for proper cleanup * Fix unused warning * Forgot one parameter unused * Move some flashattn computation to f32 for correctness --- .github/workflows/build.yml | 44 +- ci/run.sh | 15 +- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 169 ++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 288 ++++++- ggml/src/ggml-webgpu/pre_wgsl.hpp | 778 ++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 591 +++++++++++++ 6 files changed, 1838 insertions(+), 47 deletions(-) create mode 100644 ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp create mode 100644 ggml/src/ggml-webgpu/pre_wgsl.hpp create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 85601b3712..446a3750d7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -152,13 +152,13 @@ jobs: DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" - echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" curl -L -o artifact.zip \ - "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" mkdir dawn unzip artifact.zip - tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -532,13 +532,13 @@ jobs: DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" - echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" curl -L -o artifact.zip \ - "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" mkdir dawn unzip artifact.zip - tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -1704,6 +1704,34 @@ jobs: run: | GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + ggml-ci-mac-webgpu: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dawn Dependency + id: dawn-depends + run: | + DAWN_VERSION="v2.0.0" + DAWN_OWNER="reeselevine" + DAWN_REPO="dawn" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release" + echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + curl -L -o artifact.zip \ + "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip" + mkdir dawn + unzip artifact.zip + tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \ + bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + ggml-ci-mac-vulkan: runs-on: [self-hosted, macOS, ARM64] diff --git a/ci/run.sh b/ci/run.sh index 5c2d325a56..3deebd5dd3 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -105,7 +105,20 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then fi if [ ! -z ${GG_BUILD_WEBGPU} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1 -DGGML_METAL=OFF -DGGML_BLAS=OFF" + + if [ ! -z "${GG_BUILD_WEBGPU_DAWN_PREFIX}" ]; then + if [ -z "${CMAKE_PREFIX_PATH}" ]; then + export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}" + else + export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}:${CMAKE_PREFIX_PATH}" + fi + fi + + # For some systems, Dawn_DIR needs to be set explicitly, e.g., the lib64 path + if [ ! -z "${GG_BUILD_WEBGPU_DAWN_DIR}" ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DDawn_DIR=${GG_BUILD_WEBGPU_DAWN_DIR}" + fi fi if [ ! -z ${GG_BUILD_MUSA} ]; then diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp new file mode 100644 index 0000000000..7fdb4c8c8d --- /dev/null +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -0,0 +1,169 @@ +#ifndef GGML_WEBGPU_SHADER_LIB_HPP +#define GGML_WEBGPU_SHADER_LIB_HPP + +#include "ggml.h" +#include "pre_wgsl.hpp" + +#include +#include + +#define GGML_WEBGPU_F16_SIZE_BYTES 2 +#define GGML_WEBGPU_F32_SIZE_BYTES 4 +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u +// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. +#define GGML_WEBGPU_KV_SEQ_PAD 256u + +struct ggml_webgpu_flash_attn_shader_lib_context { + ggml_type kv_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; +}; + +struct ggml_webgpu_flash_attn_shader_decisions { + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + ggml_webgpu_flash_attn_shader_decisions decisions; +}; + +// This is exposed because it's necessary in supports_op +inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, + uint32_t kv_tile, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); + size_t f16_elems = 0; + size_t f32_elems = 0; + f16_elems += q_tile * head_dim_qk; // q_shmem + if (!kv_direct) { + f16_elems += kv_tile * max_head_dim; // kv_shmem + } + f16_elems += q_tile * head_dim_v; // o_shmem + if (has_mask) { + f16_elems += q_tile * kv_tile; // mask_shmem + } + f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile; // row_max_shmem + f32_elems += q_tile; // exp_sum_shmem + return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; +} + +static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.kv_direct) { + bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); + } + if (context.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn"; + + switch (context.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(context.kv_type); + + if (context.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (context.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (context.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + + if (context.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.head_dim_v); + + // For now these are not part of the variant name + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + // Add chosen Q/KV tile sizes + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (context.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + // Avoids having to use bounds-checks and decreasing performance for direct KV loads + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + // workgroup size + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + result.decisions.q_tile = q_tile; + result.decisions.kv_tile = kv_tile; + result.decisions.wg_size = wg_size; + return result; +} + +#endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c7afdfb8e9..f64f94b96f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -7,7 +7,9 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" +#include "ggml-webgpu-shader-lib.hpp" #include "ggml-wgsl-shaders.hpp" +#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -30,7 +32,7 @@ #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl -# define WEBGPU_DEBUG_BUF_ELEMS 32 +# define WEBGPU_DEBUG_BUF_ELEMS 512 #else # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG @@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool { struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; + void * context = nullptr; }; struct webgpu_command { @@ -263,6 +266,46 @@ struct webgpu_command { #endif }; +struct flash_attn_pipeline_key { + int q_type; + int kv_type; + int dst_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const flash_attn_pipeline_key & other) const { + return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +// Same hash combine function as in boost +template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +struct flash_attn_pipeline_key_hash { + size_t operator()(const flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -271,12 +314,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; - uint32_t subgroup_size; + uint32_t max_subgroup_size; -#ifndef __EMSCRIPTEN__ - bool supports_subgroup_matrix = false; - wgpu::SubgroupMatrixConfig subgroup_matrix_config; -#endif + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; std::recursive_mutex mutex; std::atomic_uint inflight_threads = 0; @@ -284,20 +327,24 @@ struct webgpu_context_struct { webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; + pre_wgsl::Preprocessor p; + std::map memset_pipelines; // variant or type index std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - std::map> set_rows_pipelines; // dst_type, vectorized - std::map> get_rows_pipelines; // src_type, vectorized + std::unordered_map flash_attn_pipelines; - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type + std::map> add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context { label(std::move(lbl)) {} }; -/* End struct definitions */ - /* WebGPU object initializations */ // Process a WGSL shader string, replacing tokens of the form {{KEY}} with @@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); - const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange(); - std::cout << "debug data:"; - for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) { - std::cout << " " << i << ": " << debug_data[i]; - } - std::cout << "\n"; + const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); + std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); } #endif @@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } +// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { return ctx->buffer; } -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) { +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); } -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) { +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); } @@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, #ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; @@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + float scale = *(float *) dst->op_params; + float max_bias; + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float logit_softcap; + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) + *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) + *(uint32_t *) &max_bias, + *(uint32_t *) &logit_softcap, + *(uint32_t *) &n_head_log2, + *(uint32_t *) &m0, + *(uint32_t *) &m1 + + }; + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) } + }; + uint32_t binding_index = 3; + if (has_mask) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + bool kv_direct = + (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + flash_attn_pipeline_key key = { + .q_type = Q->type, + .kv_type = K->type, + .dst_type = dst->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + }; + + webgpu_pipeline pipeline; + ggml_webgpu_flash_attn_shader_decisions decisions = {}; + + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + decisions = *static_cast(pipeline.context); + } else { + std::lock_guard lock(ctx->mutex); + it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + decisions = *static_cast(pipeline.context); + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, + .wg_mem_limit_bytes = + ctx->limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->max_subgroup_size }; + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); + ctx->flash_attn_pipelines.emplace(key, pipeline); + decisions = processed.decisions; + } + } + + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); ggml_unary_op unary_op = ggml_get_unary_op(dst); @@ -1397,6 +1576,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); + case GGML_OP_FLASH_ATTN_EXT: + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); @@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); futures.push_back(new_futures); } + ggml_backend_webgpu_wait(ctx, futures); ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); @@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = @@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); } +// TODO: move most initialization logic here static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); @@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_FLASH_ATTN_EXT: + { + if (!webgpu_ctx->supports_subgroup_matrix) { + break; + } + // Head dimensions must fit in workgroup memory with minimum tile sizes + size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], + has_mask, kv_direct); + if (min_bytes > limit_bytes) { + break; + } + + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; + break; + } case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; @@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { } // TODO: Does this need to be thread safe? Is it only called once? +// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); @@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->subgroup_matrix_config = config; + ctx->sg_mat_m = config.M; + ctx->sg_mat_n = config.N; + ctx->sg_mat_k = config.K; valid_subgroup_matrix_config = true; break; } @@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t #endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; + ctx->max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; @@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); + GGML_UNUSED(reason); + GGML_UNUSED(message); + //TODO: uncomment once proper free logic is in place + //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + //std::string(message).c_str()); }); dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp new file mode 100644 index 0000000000..4d4359463c --- /dev/null +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -0,0 +1,778 @@ +#ifndef PRE_WGSL_HPP +#define PRE_WGSL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace pre_wgsl { + +//============================================================== +// Options +//============================================================== +struct Options { + std::string include_path = "."; + std::vector macros; +}; + +//============================================================== +// Utility: trim +//============================================================== +static std::string trim(const std::string & s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char) s[a])) { + a++; + } + size_t b = s.size(); + while (b > a && std::isspace((unsigned char) s[b - 1])) { + b--; + } + return s.substr(a, b - a); +} + +static std::string trim_value(std::istream & is) { + std::string str; + std::getline(is, str); + return trim(str); +} + +static bool isIdentChar(char c) { + return std::isalnum(static_cast(c)) || c == '_'; +} + +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting); + +static std::string expandMacroValue(const std::string & name, + const std::unordered_map & macros, + std::unordered_set & visiting) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + visiting.insert(name); + + auto it = macros.find(name); + if (it == macros.end()) { + visiting.erase(name); + return name; + } + + const std::string & value = it->second; + if (value.empty()) { + visiting.erase(name); + return ""; + } + + std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + visiting.erase(name); + return expanded; +} + +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdentChar(line[i])) { + size_t start = i; + while (i < line.size() && isIdentChar(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += expandMacroValue(token, macros, visiting); + } else { + result += token; + } + } else { + result += line[i]; + i++; + } + } + + return result; +} + +static std::string expandMacrosRecursive(const std::string & line, + const std::unordered_map & macros) { + std::unordered_set visiting; + return expandMacrosRecursiveInternal(line, macros, visiting); +} + +//============================================================== +// Tokenizer for expressions in #if/#elif +//============================================================== +class ExprLexer { + public: + enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; + + struct Tok { + Kind kind; + std::string text; + }; + + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} + + Tok next() { + skipWS(); + if (pos >= src.size()) { + return { END, "" }; + } + + char c = src[pos]; + + // number + if (std::isdigit((unsigned char) c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char) src[pos])) { + pos++; + } + return { NUMBER, std::string(src.substr(start, pos - start)) }; + } + + // identifier + if (std::isalpha((unsigned char) c) || c == '_') { + size_t start = pos; + while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) { + pos++; + } + return { IDENT, std::string(src.substr(start, pos - start)) }; + } + + if (c == '(') { + pos++; + return { LPAREN, "(" }; + } + if (c == ')') { + pos++; + return { RPAREN, ")" }; + } + + // multi-char operators + static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" }; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return { OP, std::string(op) }; + } + } + + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return { OP, std::string(1, c) }; + } + + // unexpected + pos++; + return { END, "" }; + } + + private: + std::string_view src; + size_t pos; + + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char) src[pos])) { + pos++; + } + } +}; + +//============================================================== +// Expression Parser (recursive descent) +//============================================================== +class ExprParser { + public: + ExprParser(std::string_view expr, + const std::unordered_map & macros, + std::unordered_set & visiting) : + lex(expr), + macros(macros), + visiting(visiting) { + advance(); + } + + int parse() { return parseLogicalOr(); } + + private: + ExprLexer lex; + ExprLexer::Tok tok; + const std::unordered_map & macros; + std::unordered_set & visiting; + + void advance() { tok = lex.next(); } + + bool acceptOp(const std::string & s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; + } + return false; + } + + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; + } + return false; + } + + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); + } + return v; + } + + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); + } + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else { + break; + } + } + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { + int rhs = parseShift(); + v = (v < rhs); + } else if (acceptOp(">")) { + int rhs = parseShift(); + v = (v > rhs); + } else if (acceptOp("<=")) { + int rhs = parseShift(); + v = (v <= rhs); + } else if (acceptOp(">=")) { + int rhs = parseShift(); + v = (v >= rhs); + } else { + break; + } + } + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { + int rhs = parseAdd(); + v = (v << rhs); + } else if (acceptOp(">>")) { + int rhs = parseAdd(); + v = (v >> rhs); + } else { + break; + } + } + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { + int rhs = parseMult(); + v = (v + rhs); + } else if (acceptOp("-")) { + int rhs = parseMult(); + v = (v - rhs); + } else { + break; + } + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { + int rhs = parseUnary(); + v = (v * rhs); + } else if (acceptOp("/")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v / rhs); + } else if (acceptOp("%")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v % rhs); + } else { + break; + } + } + return v; + } + + int parseUnary() { + if (acceptOp("!")) { + return !parseUnary(); + } + if (acceptOp("-")) { + return -parseUnary(); + } + if (acceptOp("+")) { + return +parseUnary(); + } + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ')'"); + } + return v; + } + + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; + } + + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined()"); + } + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ) in defined()"); + } + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined NAME"); + } + std::string name = tok.text; + advance(); + return macros.count(name) ? 1 : 0; + } + } + + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) { + return 0; + } + if (it->second.empty()) { + return 1; + } + return evalMacroExpression(name, it->second); + } + + // unexpected + return 0; + } + + int evalMacroExpression(const std::string & name, const std::string & value) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + + visiting.insert(name); + ExprParser ep(value, macros, visiting); + int v = ep.parse(); + visiting.erase(name); + return v; + } +}; + +//============================================================== +// Preprocessor +//============================================================== +class Preprocessor { + public: + explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + opts_.include_path = "."; + } + parseMacroDefinitions(opts_.macros); + } + + std::string preprocess_file(const std::string & filename, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess(const std::string & contents, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess_includes_file(const std::string & filename) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + std::string preprocess_includes(const std::string & contents) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + private: + Options opts_; + std::unordered_map global_macros; + + enum class DirectiveMode { All, IncludesOnly }; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector & macro_defs) { + for (const auto & def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } + } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + //---------------------------------------------------------- + void buildMacros(const std::vector & additional_macros, + std::unordered_map & macros, + std::unordered_set & predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto & [name, value] : global_macros) { + predefined.insert(name); + } + + for (const auto & def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string & fname) { + std::ifstream f(fname); + if (!f.is_open()) { + throw std::runtime_error("Could not open file: " + fname); + } + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector & cond) const { + if (cond.empty()) { + return true; + } + return cond.back().active; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string processFile(const std::string & name, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + if (include_stack.count(name)) { + throw std::runtime_error("Recursive include: " + name); + } + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode); + include_stack.erase(name); + return out; + } + + std::string processIncludeFile(const std::string & fname, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack, mode); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string processString(const std::string & shader_code, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + + while (std::getline(in, line)) { + std::string t = trim(line); + + if (!t.empty() && t[0] == '#') { + bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); + if (mode == DirectiveMode::IncludesOnly && !handled) { + out << line << "\n"; + } + } else { + if (mode == DirectiveMode::IncludesOnly) { + out << line << "\n"; + } else if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacrosRecursive(line, macros); + out << expanded << "\n"; + } + } + } + + if (mode == DirectiveMode::All && !cond.empty()) { + throw std::runtime_error("Unclosed #if directive"); + } + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + bool handleDirective(const std::string & t, + std::stringstream & out, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::vector & cond, + std::unordered_set & include_stack, + DirectiveMode mode) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (mode == DirectiveMode::All && !condActive(cond)) { + return true; + } + std::string file; + iss >> file; + if (file.size() >= 2 && file.front() == '"' && file.back() == '"') { + file = file.substr(1, file.size() - 2); + } + out << processIncludeFile(file, macros, predefined_macros, include_stack, mode); + return true; + } + + if (mode == DirectiveMode::IncludesOnly) { + return false; + } + + if (cmd == "define") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + std::string value = trim_value(iss); + macros[name] = value; + return true; + } + + if (cmd == "undef") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't undef predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + macros.erase(name); + return true; + } + + if (cmd == "ifdef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "ifndef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + v = ep.parse() != 0; + } + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) { + throw std::runtime_error("#elif without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + + if (c.taken) { + c.active = false; + return true; + } + + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + bool v = ep.parse() != 0; + c.active = v; + if (v) { + c.taken = true; + } + return true; + } + + if (cmd == "else") { + if (cond.empty()) { + throw std::runtime_error("#else without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return true; + } + + if (cmd == "endif") { + if (cond.empty()) { + throw std::runtime_error("#endif without #if"); + } + cond.pop_back(); + return true; + } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } +}; + +} // namespace pre_wgsl + +#endif // PRE_WGSL_HPP diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl new file mode 100644 index 0000000000..de7c132a62 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -0,0 +1,591 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. +#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +// Ok not to put these in a define block, compiler will remove if unused +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array; + +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif + +@group(0) @binding(DST_BINDING) var dst: array; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +// The number of Q rows processed per workgroup +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; // output shmem + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for diagonal matrix during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + let mask_term = slope * mask_val; + v += mask_term; +#endif + return v; +} + + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + // batch index + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + // clear inter_shmem to ensure zero-initialized accumulators + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck + for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { + let inter_offset = kv_block * SG_MAT_N; + var acc: subgroup_matrix_result = subgroupMatrixLoad< + subgroup_matrix_result>(&inter_shmem, inter_offset, false, KV_TILE); +#ifdef KV_DIRECT + let k_block_row = kv_tile + kv_block * SG_MAT_N; + let k_global_offset = k_head_offset + k_block_row * params.stride_k1; +#else + let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; +#endif + for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { + // load q submatrix from shared memory + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &q_shmem, + head_dim_block, + false, + HEAD_DIM_QK + ); + + // load k submatrix from device or shared memory +#ifdef KV_DIRECT + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &K, + k_global_offset + head_dim_block, + true, + params.stride_k1 + ); +#else + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + k_block_offset + head_dim_block, + true, + HEAD_DIM_QK + ); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); + } + + // store acc to shared memory for softmax (S matrix from paper) + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } + +#ifdef MASK + // load mask tile into shared memory for this KV block + // TODO: optimize and skip if mask is -INF for the entire tile + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } +#endif + + workgroupBarrier(); + + // online softmax + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // initialize running max for this row + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + + // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +#endif + + workgroupBarrier(); + + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + for (var head_dim_block = subgroup_id * SG_MAT_N; + head_dim_block < HEAD_DIM_V; + head_dim_block += num_subgroups * SG_MAT_N) { + // load O submatrix from shared memory + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + &o_shmem, + head_dim_block, + false, + HEAD_DIM_V + ); + + for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { + let p_offset = kv_block * SG_MAT_N; + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &inter_shmem, + p_offset, + false, + KV_TILE + ); + + // load V submatrix from global or shared memory +#ifdef KV_DIRECT + let v_block_row = kv_tile + kv_block * SG_MAT_N; + let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &V, + v_global_offset, + false, + params.stride_v1 + ); +#else + let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + v_block_offset + head_dim_block, + false, + HEAD_DIM_V + ); +#endif + // O += P * V + o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); + } + + // store O back to shared memory + subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); + } + + workgroupBarrier(); + } + +#ifdef SINKS + // add sinks (applied once after processing all KV tiles) + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } + } + + workgroupBarrier(); +#endif + + // write output back to global memory + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; + let scaled = f32(o_val) * scale; + dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; + } + } +}