#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP

#include "ggml-wgsl-shaders.hpp"
#include "ggml.h"
#include "pre_wgsl.hpp"

// NOTE(review): in this copy the six #include directives below carry no header
// name, and template argument lists are missing throughout (e.g. `template inline`,
// `std::hash{}`, `std::vector defines`, `std::shared_ptr context`). This looks like
// angle-bracketed text lost in extraction -- confirm against the repository copy.
// The code in this header uses std::string, std::vector, std::shared_ptr,
// std::hash, std::max, std::to_string and wgpu:: types, so their headers must be
// reachable from here.
#include
#include
#include
#include
#include
#include

// Byte sizes of scalar element types, used when estimating workgroup-memory
// footprints (ggml_webgpu_flash_attn_wg_mem_bytes, argsort workgroup sizing).
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4

// Flash-attention tuning preferences; presumably the preferred number of KV
// subgroup tiles and the preferred workgroup size -- TODO confirm at the use sites.
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
// Upper bound on the workgroup size of the argsort merge pass (it is further
// clamped to the device's max_wg_size).
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u

// Matrix multiplication parameters

// Register tiling parameters
#define WEBGPU_MUL_MAT_TILE_M 8
#define WEBGPU_MUL_MAT_TILE_N 8
#define WEBGPU_MUL_MAT_WG_SIZE_M 8
#define WEBGPU_MUL_MAT_WG_SIZE_N 8
#define WEBGPU_MUL_MAT_TILE_K 32

// Subgroup matrix parameters
// The number of subgroups in the M dimension
#define WEBGPU_MUL_MAT_SUBGROUP_M 2
// The number of subgroups in the N dimension
#define WEBGPU_MUL_MAT_SUBGROUP_N 2
// The number of subgroup matrices each subgroup accumulates over
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

// Matrix-vector multiplication parameters
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
// Must be multiple of 4 to work with vectorized paths, and must divide
// mul_mat_vec wg size
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
// Requires 32 threads per output (wg_size/outputs_per_wg == 32; 256 / 8 == 32)
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
// Requires at least two (and multiple of 2) k-quant blocks per tile
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512

// default size for legacy matrix multiplication
#define WEBGPU_MUL_MAT_WG_SIZE 256

// Same hash combine function as in boost: folds std::hash of `value` into `seed`
// as seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2). Every *_pipeline_key_hash
// functor below builds its hash by calling this once per key field.
template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Inputs for the ggml_webgpu_shader_lib get_*_pipeline methods: the op's tensors
// plus the device limits/capabilities that shape the generated shader.
struct ggml_webgpu_shader_lib_context {
    // Source/destination tensors of the op. These have no default initializer, so
    // callers must set every slot the chosen getter dereferences.
    ggml_tensor * src0;
    ggml_tensor * src1;
    ggml_tensor * src2;
    ggml_tensor * src3;
    ggml_tensor * src4;
    ggml_tensor * dst;

    // Workgroup-size cap used by the getters (also not default-initialized).
    uint32_t max_wg_size;
    // Workgroup-memory budget in bytes (used e.g. to size argsort and solve_tri).
    size_t wg_mem_limit_bytes = 0;
    // inplace selects the INPLACE shader variants; overlap/src_overlap presumably
    // feed ggml_webgpu_binary_pipeline_key -- confirm in the binary getter.
    bool inplace = false;
    bool overlap = false;
    bool src_overlap = false;
    // Subgroup-matrix support and shape, and the maximum subgroup size.
    bool supports_subgroup_matrix = false;
    uint32_t sg_mat_m = 0;
    uint32_t sg_mat_n = 0;
    uint32_t sg_mat_k = 0;
    uint32_t max_subgroup_size = 0;
};

// A created compute pipeline together with its name and optional build-time decisions.
struct webgpu_pipeline {
    wgpu::ComputePipeline pipeline;
    std::string name;
    // Type-erased decisions record: different getters store differently typed
    // structs here (e.g. ggml_webgpu_set_rows_shader_decisions,
    // ggml_webgpu_ssm_conv_shader_decisions); readers must know which one to expect.
    std::shared_ptr context = nullptr;
};

// Decisions record that only carries the chosen workgroup size.
struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size = 0;
};

// Result of a preprocess_* helper: the final WGSL text, a variant name that encodes
// the defines used, and optional decisions.
struct ggml_webgpu_processed_shader {
    std::string wgsl;
    std::string variant;
    std::shared_ptr decisions;
};

// Tiling chosen for the SSM conv shader (mirrors its BLOCK_SIZE / TOKENS_PER_WG defines).
struct ggml_webgpu_ssm_conv_shader_decisions {
    uint32_t block_size;
    uint32_t tokens_per_wg;
};

/** Argsort **/
// Argsort-specific limits and sort order (the argsort getters treat order 0 as
// ascending and 1 as descending).
struct ggml_webgpu_argsort_shader_lib_context {
    uint32_t max_wg_size;
    size_t wg_mem_limit_bytes;
    int32_t order;
};

/** Set Rows **/
// Cache key for set_rows pipelines.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type;  // ggml_type of dst; the getter supports F32 and F16 only
    int vec4;      // nonzero when src0->ne[0] % 4 == 0
    int i64_idx;   // nonzero when the row indices (src1) are I64

    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};

struct ggml_webgpu_set_rows_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.dst_type);
        ggml_webgpu_hash_combine(seed, key.vec4);
        ggml_webgpu_hash_combine(seed, key.i64_idx);
        return seed;
    }
};

// Decisions stored on a set_rows pipeline: the key flags plus the workgroup size.
struct ggml_webgpu_set_rows_shader_decisions {
    bool vec4;
    bool i64_idx;
    uint32_t wg_size;
};

/** Set **/
// Cache key for set pipelines: element type (F32 or I32) and whether the op is in place.
struct ggml_webgpu_set_pipeline_key {
    ggml_type type;
    bool inplace;

    bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
        return type == other.type && inplace == other.inplace;
    }
};

struct ggml_webgpu_set_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

/** Get Rows **/
// Cache key for get_rows pipelines; vectorized is 1 only for F32 sources whose
// dst->ne[0] is a multiple of 4.
struct ggml_webgpu_get_rows_pipeline_key {
    ggml_type src_type;
    int vectorized;

    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
        return src_type == other.src_type && vectorized == other.vectorized;
    }
};

struct ggml_webgpu_get_rows_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src_type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        return seed;
    }
};

/** Row Norm **/
// Cache key for row-norm pipelines; op is GGML_OP_RMS_NORM or GGML_OP_L2_NORM.
struct ggml_webgpu_row_norm_pipeline_key {
    ggml_op op;
    bool inplace;

    bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const {
        return op == other.op && inplace == other.inplace;
    }
};

struct ggml_webgpu_row_norm_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

/** Pad **/
// Cache key for pad pipelines; circular mirrors op param 8 of dst being nonzero.
struct ggml_webgpu_pad_pipeline_key {
    bool circular;

    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};

struct ggml_webgpu_pad_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.circular);
        return seed;
    }
};

/** Solve Tri **/
// Cache key for solve_tri pipelines: dst type, n = src0->ne[0] (baked into the
// shader as N) and k = src1->ne[0].
// NOTE(review): get_solve_tri_pipeline derives no define from k, so pipelines that
// differ only in k are compiled as duplicates -- confirm whether k belongs in the key.
struct ggml_webgpu_solve_tri_pipeline_key {
    int type;
    int n;
    int k;

    bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
        return type == other.type && n == other.n && k == other.k;
    }
};

struct ggml_webgpu_solve_tri_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.n);
        ggml_webgpu_hash_combine(seed, key.k);
        return seed;
    }
};

/** SSM Conv **/
// Cache key for ssm_conv pipelines; vectorized is set when src1->ne[0] == 4.
// Its hash functor is defined after the gated-delta-net structs.
struct ggml_webgpu_ssm_conv_pipeline_key {
    int type;
    int vectorized;

    bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
        return type == other.type && vectorized == other.vectorized;
    }
};

/** Gated Delta Net **/
// Cache key for gated_delta_net pipelines: s_v = src2->ne[0] (used for both S_V and
// WG_SIZE) and kda set when src3->ne[0] == src2->ne[0].
struct ggml_webgpu_gated_delta_net_pipeline_key {
    int type;
    int s_v;
    int kda;

    bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
        return type == other.type && s_v == other.s_v && kda == other.kda;
    }
};

struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.s_v);
        ggml_webgpu_hash_combine(seed, key.kda);
        return seed;
    }
};

// Hash for ggml_webgpu_ssm_conv_pipeline_key.
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        return seed;
    }
};

/** Scale **/
// Cache key for scale pipelines (only the in-place flag varies).
struct ggml_webgpu_scale_pipeline_key {
    int inplace;

    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
};

struct ggml_webgpu_scale_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

/** Concat **/
// Cache key for concat pipelines (element type only).
struct ggml_webgpu_concat_pipeline_key {
    int type;

    bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
};

struct ggml_webgpu_concat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        return seed;
    }
};

/** Repeat **/
// Cache key for repeat pipelines (element type only).
struct ggml_webgpu_repeat_pipeline_key {
    int type;

    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
};

struct ggml_webgpu_repeat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        return seed;
    }
};

/** Binary **/
// Cache key for binary-op pipelines: element type, op, and the aliasing flags
// (inplace / overlap / src_overlap) that select different shader variants.
struct ggml_webgpu_binary_pipeline_key {
    int type;
    int op;
    bool inplace;
    bool overlap;
    bool src_overlap;

    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
               src_overlap == other.src_overlap;
    }
};

struct ggml_webgpu_binary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.inplace);
        ggml_webgpu_hash_combine(seed, key.overlap);
        ggml_webgpu_hash_combine(seed, key.src_overlap);
        return seed;
    }
};

/** Unary **/
// Cache key for unary-style pipelines.
struct ggml_webgpu_unary_pipeline_key {
    int type;
    int op;
    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
    bool inplace;
    ggml_tri_type ttype;  // only used for GGML_OP_TRI

    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
               ttype == other.ttype;
    }
};

struct ggml_webgpu_unary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.is_unary);
        ggml_webgpu_hash_combine(seed, key.inplace);
        ggml_webgpu_hash_combine(seed, key.ttype);
        return seed;
    }
};

/** FlashAttention */
// Cache key for flash-attention pipelines; every field participates in equality and hashing.
struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
    bool kv_direct;  // when set, no K/V staging buffer is counted in ggml_webgpu_flash_attn_wg_mem_bytes
    bool has_mask;
    bool has_sinks;
    bool uses_logit_softcap;
    bool use_vec;

    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
    }
};

struct ggml_webgpu_flash_attn_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.kv_type);
        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
        ggml_webgpu_hash_combine(seed, key.head_dim_v);
        ggml_webgpu_hash_combine(seed, key.kv_direct);
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sinks);
        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
        ggml_webgpu_hash_combine(seed, key.use_vec);
        return seed;
    }
};

// Inputs for building a flash-attention pipeline: the key plus subgroup-matrix
// shape, workgroup-memory budget (bytes) and maximum subgroup size.
struct ggml_webgpu_flash_attn_shader_lib_context {
    ggml_webgpu_flash_attn_pipeline_key key;
    uint32_t sg_mat_m;
    uint32_t sg_mat_n;
    uint32_t sg_mat_k;
    size_t wg_mem_limit_bytes;
    uint32_t max_subgroup_size;
};

// Tile sizes and workgroup size chosen for a flash-attention pipeline (0 = unset).
struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t q_tile = 0;
    uint32_t kv_tile = 0;
    uint32_t wg_size = 0;
};

// Returns a per-head-dim tuning value for the f16 vec flash-attention path; it is 1
// unless kv_type is F16 and head_dim_qk == head_dim_v. The meaning of the value is
// defined by its callers (presumably an element/vector count per lane -- TODO confirm).
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
    // Keep conservative defaults unless this is the f16 vec-split shape family.
    if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
        return 1u;
    }

    // Head-dim specializations used by the tuned vec f16 path.
    // (128 is listed explicitly even though it returns the same value as default.)
    switch (key.head_dim_qk) {
        case 64:
            return 2u;
        case 96:
            return 4u;
        case 128:
            return 1u;
        case 192:
            return 2u;
        case 576:
            return 2u;
        default:
            return 1u;
    }
}

// Cache key for the flash-attention vec reduce pipeline. Unlike the other keys,
// equality is provided by the free operator== below.
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
    uint32_t head_dim_v;
    uint32_t wg_size;
};

struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.head_dim_v);
        ggml_webgpu_hash_combine(seed, key.wg_size);
        return seed;
    }
};

inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
                       const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
    return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
}

struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
    uint32_t max_wg_size;
};

// Preprocesses the flash-attention vec reduce shader with HEAD_DIM_V (from the key)
// and WG_SIZE (from context.max_wg_size); the variant name encodes both values.
// NOTE(review): WG_SIZE uses context.max_wg_size rather than context.key.wg_size, so
// the cache key and the compiled shader only agree if callers keep the two equal.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
    std::vector defines;
    std::string variant = "flash_attn_vec_reduce";

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    variant += std::string("_wg") + std::to_string(context.max_wg_size);

    ggml_webgpu_processed_shader result;
    result.wgsl = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}

// Cache key for the flash-attention block pipeline (Q and KV tile sizes).
struct ggml_webgpu_flash_attn_blk_pipeline_key {
    uint32_t q_tile;
    uint32_t kv_tile;

    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
        return q_tile == other.q_tile && kv_tile == other.kv_tile;
    }
};

struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.q_tile);
        ggml_webgpu_hash_combine(seed, key.kv_tile);
        return seed;
    }
};

struct ggml_webgpu_flash_attn_blk_shader_lib_context {
    ggml_webgpu_flash_attn_blk_pipeline_key key;
    uint32_t max_wg_size;
};

// Preprocesses the flash-attention block shader. Q_TILE/KV_TILE come from the key;
// WG_SIZE is the largest power of two not exceeding context.max_wg_size (1 when
// max_wg_size is 0). The variant name ("flash_attn_vec_blk_...") encodes all three.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
    std::vector defines;
    std::string variant = "flash_attn_vec_blk";

    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
    variant += std::string("_qt") + std::to_string(context.key.q_tile);

    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);

    // Round max_wg_size down to a power of two.
    // NOTE(review): if max_wg_size >= 2^31 the shift wraps to 0 and this never terminates.
    uint32_t wg_size = 1;
    while ((wg_size << 1) <= context.max_wg_size) {
        wg_size <<= 1;
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    variant += std::string("_wg") + std::to_string(wg_size);

    ggml_webgpu_processed_shader result;
    result.wgsl = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}

// This is exposed because it's necessary in supports_op.
// Returns the workgroup memory, in bytes, used by the tiled flash-attention shader:
// f16 buffers for Q, K/V (sized with the larger head dim; skipped when kv_direct),
// O, the mask (only when has_mask) and inter_shmem, plus two f32 buffers of
// q_tile elements (row_max_shmem, exp_sum_shmem).
// NOTE(review): the tile products are evaluated in uint32_t before being widened to
// size_t; fine for realistic sizes, but they would wrap for very large tiles/head dims.
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                  uint32_t kv_tile,
                                                  uint32_t head_dim_qk,
                                                  uint32_t head_dim_v,
                                                  bool has_mask,
                                                  bool kv_direct) {
    const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
    size_t f16_elems = 0;
    size_t f32_elems = 0;
    f16_elems += q_tile * head_dim_qk;  // q_shmem
    if (!kv_direct) {
        f16_elems += kv_tile * max_head_dim;  // kv_shmem
    }
    f16_elems += q_tile * head_dim_v;  // o_shmem
    if (has_mask) {
        f16_elems += q_tile * kv_tile;  // mask_shmem
    }
    f16_elems += q_tile * kv_tile;  // inter_shmem
    f32_elems += q_tile;  // row_max_shmem
    f32_elems += q_tile;  // exp_sum_shmem
    return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}

/** Matrix Multiplication **/
// Cache key for the legacy mul_mat pipelines (source types only).
struct ggml_webgpu_legacy_mul_mat_pipeline_key {
    ggml_type src0_type;
    ggml_type src1_type;

    bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
        return src0_type == other.src0_type && src1_type == other.src1_type;
    }
};

struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src0_type);
        ggml_webgpu_hash_combine(seed, key.src1_type);
        return seed;
    }
};

// Cache key for mat-vec pipelines; vectorized is 1 only when src0 is F32/F16 and
// both src0->ne[0] and dst->ne[0] are multiples of 4.
struct ggml_webgpu_mul_mat_vec_pipeline_key {
    ggml_type src0_type;
    ggml_type src1_type;
    int vectorized;

    bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
    }
};

struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src0_type);
        ggml_webgpu_hash_combine(seed, key.src1_type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        return seed;
    }
};

// Launch/tiling parameters chosen for a mat-vec pipeline.
struct ggml_webgpu_mul_mat_vec_shader_decisions {
    uint32_t wg_size;
    uint32_t tile_k;
    uint32_t outputs_per_wg;
    uint32_t vec_size;
};

// Cache key for the fast mat-mat pipelines (register-tiled or subgroup-matrix).
struct ggml_webgpu_mul_mat_pipeline_key {
    ggml_type src0_type;
    ggml_type src1_type;
    int vectorized;
    int use_subgroup_matrix;

    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
               use_subgroup_matrix == other.use_subgroup_matrix;
    }
};

struct ggml_webgpu_mul_mat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src0_type);
        ggml_webgpu_hash_combine(seed, key.src1_type);
        ggml_webgpu_hash_combine(seed, key.vectorized);
        ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
        return seed;
    }
};

// Launch/tiling parameters chosen for a fast mat-mat pipeline.
struct ggml_webgpu_mul_mat_shader_decisions {
    uint32_t tile_k;
    uint32_t wg_size_m;
    uint32_t wg_size_n;
    uint32_t wg_size;
    uint32_t outputs_per_wg;
    int use_subgroup_matrix;

    uint32_t tile_m;
    uint32_t tile_n;

    // Subgroup matrix parameters
    uint32_t subgroup_m;
    uint32_t subgroup_n;
    uint32_t subgroup_matrix_m;
    uint32_t subgroup_matrix_n;

    uint32_t mul_mat_wg_size;
};

/** Cpy **/
// Cache key for copy/convert pipelines (source and destination types).
struct ggml_webgpu_cpy_pipeline_key {
    ggml_type src_type;
    ggml_type dst_type;

    bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
        return src_type == other.src_type && dst_type == other.dst_type;
    }
};

struct ggml_webgpu_cpy_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.src_type);
        ggml_webgpu_hash_combine(seed, key.dst_type);
        return seed;
    }
};

/** Glu **/
// Cache key for GLU pipelines: GLU variant, element type, and the split flag
// (presumably whether gate and input come from separate tensors -- TODO confirm).
struct ggml_webgpu_glu_pipeline_key {
    ggml_glu_op glu_op;
    ggml_type type;
    bool split;

    bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
        return glu_op == other.glu_op && type == other.type && split == other.split;
    }
};

struct ggml_webgpu_glu_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.glu_op);
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.split);
        return seed;
    }
};

/** Rope **/
// Cache key for RoPE pipelines (type, in-place flag, and has_ff -- presumably
// whether frequency factors are supplied; TODO confirm).
struct ggml_webgpu_rope_pipeline_key {
    ggml_type type;
    bool inplace;
    bool has_ff;

    bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
        return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
    }
};

struct ggml_webgpu_rope_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.inplace);
        ggml_webgpu_hash_combine(seed, key.has_ff);
        return seed;
    }
};

/** SoftMax **/
// Cache key for softmax pipelines (mask type/presence, sink presence, in-place flag).
struct ggml_webgpu_soft_max_pipeline_key {
    ggml_type mask_type;
    bool has_mask;
    bool has_sink;
    bool inplace;

    bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
        return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
               inplace == other.inplace;
    }
};

struct ggml_webgpu_soft_max_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.mask_type);
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sink);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

class ggml_webgpu_shader_lib {
    wgpu::Device device;
    pre_wgsl::Preprocessor preprocessor;

    std::unordered_map sum_rows_pipelines;  // key is fixed, no variants yet
    std::unordered_map argmax_pipelines;  // key is vec4
    std::unordered_map argsort_pipelines;  // key is order
    std::unordered_map argsort_merge_pipelines;  // key is order
    std::unordered_map cumsum_pipelines;  // key is fixed, no variants yet
    std::unordered_map row_norm_pipelines;  // op/inplace
    std::unordered_map get_rows_pipelines;  // src_type, vectorized
    std::unordered_map unary_pipelines;  // type/op/inplace
    std::unordered_map scale_pipelines;  // inplace
    std::unordered_map solve_tri_pipelines;  // type/n/k
    std::unordered_map ssm_conv_pipelines;  // type/vectorized
    std::unordered_map gated_delta_net_pipelines;  // type/S_v/kda
    std::unordered_map pad_pipelines;  // circular/non-circular
    std::unordered_map binary_pipelines;  // type/op/inplace/overlap
    std::unordered_map concat_pipelines;  // type
    std::unordered_map repeat_pipelines;  // type
    std::unordered_map flash_attn_pipelines;
    std::unordered_map flash_attn_vec_reduce_pipelines;
    std::unordered_map flash_attn_blk_pipelines;
    std::unordered_map mul_mat_legacy_pipelines;  // legacy mul_mat (non-subgroup/non-regtile/non-vec)
    std::unordered_map mul_mat_vec_pipelines;  // fast mat-vec (n==1)
    std::unordered_map mul_mat_fast_pipelines;  // fast mat-mat (reg-tile or subgroup)
    std::unordered_map set_rows_pipelines;
    std::unordered_map set_pipelines;
    std::unordered_map cpy_pipelines;
    std::unordered_map
glu_pipelines; std::unordered_map rope_pipelines; std::unordered_map soft_max_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = sum_rows_pipelines.find(1); if (it != sum_rows_pipelines.end()) { return it->second; } std::vector defines; defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); return sum_rows_pipelines[1]; } webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = { .op = context.dst->op, .inplace = context.inplace, }; auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { return it->second; } std::vector defines; std::string variant; switch (key.op) { case GGML_OP_RMS_NORM: defines.push_back("RMS_NORM"); variant = "rms_norm"; break; case GGML_OP_L2_NORM: defines.push_back("L2_NORM"); variant = "l2_norm"; break; default: GGML_ABORT("Unsupported op for row_norm shader"); } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_row_norm, defines); row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); return row_norm_pipelines[key]; } webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { bool vec4 = context.src0->ne[0] % 4 == 0; auto it = argmax_pipelines.find(vec4); if (it != argmax_pipelines.end()) { return it->second; } std::string variant = "argmax"; std::vector defines; defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); 
if (vec4) { defines.push_back("VEC4"); variant += "_vec4"; } auto processed = preprocessor.preprocess(wgsl_argmax, defines); argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant); return argmax_pipelines.at(vec4); } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, .vec4 = context.src0->ne[0] % 4 == 0, .i64_idx = context.src1->type == GGML_TYPE_I64 }; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "set_rows"; switch (context.dst->type) { case GGML_TYPE_F32: defines.push_back("DST_F32"); variant += "_dstf32"; break; case GGML_TYPE_F16: defines.push_back("DST_F16"); variant += "_dstf16"; break; default: GGML_ABORT("Unsupported dst type for set_rows shader"); } if (key.vec4) { defines.push_back("VEC4"); variant += "_vec4"; } if (key.i64_idx) { defines.push_back("I64_IDX"); variant += "_i64idx"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_set_rows, defines); auto decisions = std::make_shared(); decisions->vec4 = key.vec4; decisions->i64_idx = key.i64_idx; decisions->wg_size = context.max_wg_size; set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); set_rows_pipelines[key].context = decisions; return set_rows_pipelines[key]; } webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "set"; switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_I32: defines.push_back("TYPE_I32"); variant += "_i32"; break; default: 
GGML_ABORT("Unsupported type for set shader"); } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_set, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; set_pipelines[key] = pipeline; return set_pipelines[key]; } webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = cumsum_pipelines.find(1); if (it != cumsum_pipelines.end()) { return it->second; } std::vector defines; defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_cumsum, defines); cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum"); return cumsum_pipelines[1]; } webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) { bool is_top_k = context.dst->op == GGML_OP_TOP_K; // ascending order is 0, descending order is 1 const int32_t order = is_top_k ? 
(int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); auto it = argsort_pipelines.find(order); if (it != argsort_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "argsort"; defines.push_back(std::string("ORDER=") + std::to_string(order)); variant += std::string("_order") + std::to_string(order); uint32_t wg_size = 1; while (wg_size * 2 <= context.max_wg_size && wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { wg_size *= 2; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_argsort, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); argsort_pipelines[order].context = decisions; return argsort_pipelines[order]; } webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) { bool is_top_k = context.dst->op == GGML_OP_TOP_K; // ascending order is 0, descending order is 1 const int32_t order = is_top_k ? 
(int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); auto it = argsort_merge_pipelines.find(order); if (it != argsort_merge_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "argsort_merge"; defines.push_back(std::string("ORDER=") + std::to_string(order)); variant += std::string("_order") + std::to_string(order); uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines); argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); return argsort_merge_pipelines[order]; } webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; ggml_webgpu_get_rows_pipeline_key key = { .src_type = context.src0->type, .vectorized = (int) vectorized, }; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "get_rows"; const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type); const char * type_str = type_traits->type_name; switch (key.src_type) { case GGML_TYPE_F32: defines.push_back("FLOAT_PARALLEL"); if (key.vectorized) { defines.push_back("F32_VEC"); defines.push_back("SRC_TYPE=vec4"); defines.push_back("DST_TYPE=vec4"); defines.push_back("BLOCK_SIZE=4u"); } else { defines.push_back("F32"); defines.push_back("SRC_TYPE=f32"); defines.push_back("DST_TYPE=f32"); defines.push_back("BLOCK_SIZE=1u"); } variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("FLOAT_PARALLEL"); defines.push_back("F16"); defines.push_back("SRC_TYPE=f16"); defines.push_back("DST_TYPE=f32"); defines.push_back("BLOCK_SIZE=1u"); variant += "_f16"; break; case GGML_TYPE_I32: defines.push_back("FLOAT_PARALLEL"); 
defines.push_back("I32"); defines.push_back("SRC_TYPE=i32"); defines.push_back("DST_TYPE=i32"); defines.push_back("BLOCK_SIZE=1u"); variant += "_i32"; break; default: { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); defines.push_back(type_upper + "_SCALE_MIN"); defines.push_back(type_upper + "_TABLES"); defines.push_back(type_upper + "_GRID"); variant += "_"; variant += type_str; defines.push_back(std::string("SRC_TYPE=") + type_str); defines.push_back("DST_TYPE=f32"); if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); } else { defines.push_back("BLOCK_SIZE=1u"); } break; } } if (key.vectorized) { variant += "_vec"; } defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_get_rows, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; get_rows_pipelines[key] = pipeline; return get_rows_pipelines[key]; } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "scale"; if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_scale, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline 
pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; scale_pipelines[key] = pipeline; return scale_pipelines[key]; } webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_solve_tri_pipeline_key key = { .type = context.dst->type, .n = (int) context.src0->ne[0], .k = (int) context.src1->ne[0], }; auto it = solve_tri_pipelines.find(key); if (it != solve_tri_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "solve_tri"; switch (key.type) { case GGML_TYPE_F32: variant += "_f32"; break; default: GGML_ABORT("Unsupported type for solve_tri shader"); } const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); const uint32_t k_tile = wg_size; const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); defines.push_back(std::string("N=") + std::to_string(key.n)); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; solve_tri_pipelines[key] = pipeline; return solve_tri_pipelines[key]; } webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_ssm_conv_pipeline_key key = { .type = context.dst->type, .vectorized = context.src1->ne[0] == 4, }; auto it = ssm_conv_pipelines.find(key); if (it != ssm_conv_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "ssm_conv"; switch (key.type) { case GGML_TYPE_F32: variant += "_f32"; break; default: GGML_ABORT("Unsupported 
type for ssm_conv shader"); } if (key.vectorized) { defines.push_back("VECTORIZED"); variant += "_vec4"; } constexpr uint32_t block_size = 32u; constexpr uint32_t tokens_per_wg = 8u; defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); auto decisions = std::make_shared(); decisions->block_size = block_size; decisions->tokens_per_wg = tokens_per_wg; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; ssm_conv_pipelines[key] = pipeline; return ssm_conv_pipelines[key]; } webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_gated_delta_net_pipeline_key key = { .type = context.dst->type, .s_v = (int) context.src2->ne[0], .kda = context.src3->ne[0] == context.src2->ne[0], }; auto it = gated_delta_net_pipelines.find(key); if (it != gated_delta_net_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "gated_delta_net"; switch (key.type) { case GGML_TYPE_F32: variant += "_f32"; break; default: GGML_ABORT("Unsupported type for gated_delta_net shader"); } if (key.kda) { defines.push_back("KDA"); variant += "_kda"; } defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); gated_delta_net_pipelines[key] = pipeline; return gated_delta_net_pipelines[key]; } webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { return it->second; } std::vector defines; std::string variant = 
"pad"; if (key.circular) { defines.push_back("CIRCULAR"); variant += "_circular"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_pad, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; pad_pipelines[key] = pipeline; return pad_pipelines[key]; } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_vec_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0, }; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "mul_mat_vec"; // src0 type (matrix row) switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); defines.push_back("MUL_ACC_FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); defines.push_back("MUL_ACC_FLOAT"); variant += "_f16"; break; default: { // Quantized types: use helpers but accumulate in f16 const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); std::string src0_name = src0_traits->type_name; std::string type_upper = src0_name; variant += "_" + src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); break; } } // src1 type (vector) switch (context.src1->type) { case 
GGML_TYPE_F32: defines.push_back("SRC1_INNER_TYPE=f32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC1_INNER_TYPE=f16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); } // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; if (key.src0_type >= GGML_TYPE_Q2_K) { tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; mul_mat_vec_pipelines[key] = pipeline; return mul_mat_vec_pipelines[key]; } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type, .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 
1 : 0, .use_subgroup_matrix = context.supports_subgroup_matrix }; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { return it->second; } const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile; std::vector defines; std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile"; // src1 type switch (context.src1->type) { case GGML_TYPE_F32: defines.push_back("SRC1_INNER_TYPE=f32"); break; case GGML_TYPE_F16: defines.push_back("SRC1_INNER_TYPE=f16"); break; default: GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); } // src0 type const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); const char * src0_name = src0_traits->type_name; switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); defines.push_back("FLOAT"); defines.push_back("MUL_ACC_FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); defines.push_back("FLOAT"); defines.push_back("MUL_ACC_FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f16"; break; default: { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); defines.push_back("INIT_SRC0_SHMEM_" + type_upper); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); variant += std::string("_") + src0_name; break; } } // VEC/SCALAR controls defines.push_back(key.vectorized ? 
"VEC" : "SCALAR"); // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); // Subgroup matrix specifics if (key.use_subgroup_matrix) { defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); } // variant suffix for src1 type variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? 
"f32" : "f16"); if (key.vectorized) { variant += "_vectorized"; } if (!key.use_subgroup_matrix) { defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->use_subgroup_matrix = key.use_subgroup_matrix; if (key.use_subgroup_matrix) { decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; decisions->wg_size = context.max_subgroup_size; } else { decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; } webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; mul_mat_fast_pipelines[key] = pipeline; return mul_mat_fast_pipelines[key]; } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, .src1_type = context.src1->type }; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "mul_mat"; switch (context.src1->type) { case GGML_TYPE_F32: defines.push_back("SRC1_TYPE=f32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC1_TYPE=f16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); } const struct ggml_type_traits * src0_traits = 
ggml_get_type_traits(context.src0->type); const char * src0_name = src0_traits->type_name; switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_TYPE=f32"); defines.push_back("FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_TYPE=f16"); defines.push_back("FLOAT"); variant += "_f16"; break; default: { // quantized types std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back(std::string("SRC0_TYPE=") + src0_name); defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); defines.push_back(type_upper + "_SCALE_MIN"); defines.push_back(type_upper + "_TABLES"); defines.push_back(type_upper + "_GRID"); variant += std::string("_") + src0_name; break; } } auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); auto decisions = std::make_shared(); decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; mul_mat_legacy_pipelines[key] = pipeline; return mul_mat_legacy_pipelines[key]; } webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; ggml_webgpu_unary_pipeline_key key = { .type = context.dst->type, .op = op, .is_unary = is_unary, .inplace = context.inplace, .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), }; auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { return it->second; } std::vector defines; std::string variant = key.is_unary ? 
ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); defines.push_back(variant); switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("TYPE_F16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported type for unary shader"); } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } if (op == GGML_OP_TRI) { switch (key.ttype) { case GGML_TRI_TYPE_LOWER: defines.push_back("TRI_TYPE_LOWER"); variant += "_tri_type_lower"; break; case GGML_TRI_TYPE_LOWER_DIAG: defines.push_back("TRI_TYPE_LOWER_DIAG"); variant += "_tri_type_lower_diag"; break; case GGML_TRI_TYPE_UPPER: defines.push_back("TRI_TYPE_UPPER"); variant += "_tri_type_upper"; break; case GGML_TRI_TYPE_UPPER_DIAG: defines.push_back("TRI_TYPE_UPPER_DIAG"); variant += "_tri_upper_diag"; break; default: GGML_ABORT("Unsupported ggml_tri_type for unary shader"); } } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_unary, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; unary_pipelines[key] = pipeline; return unary_pipelines[key]; } webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = { .type = context.dst->type, .op = context.dst->op, .inplace = context.inplace, .overlap = context.overlap, .src_overlap = context.src_overlap, }; auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { return it->second; } std::vector defines; std::string op_name = ggml_op_name((ggml_op) key.op); std::string variant = op_name; defines.push_back(std::string("OP_") + op_name); switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_F16: 
defines.push_back("TYPE_F16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported type for binary shader"); } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } else if (key.overlap) { defines.push_back("OVERLAP"); variant += "_overlap"; } else if (key.src_overlap) { defines.push_back("SRC_OVERLAP"); variant += "_src_overlap"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_binary, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; binary_pipelines[key] = pipeline; return binary_pipelines[key]; } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = { .type = context.dst->type, }; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "concat"; switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_I32: defines.push_back("TYPE_I32"); variant += "_i32"; break; default: GGML_ABORT("Unsupported type for concat shader"); } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_concat, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; concat_pipelines[key] = pipeline; return concat_pipelines[key]; } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_repeat_pipeline_key key = { .type = context.dst->type, }; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { return it->second; } std::vector defines; std::string 
variant = "repeat"; switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_I32: defines.push_back("TYPE_I32"); variant += "_i32"; break; case GGML_TYPE_I16: defines.push_back("TYPE_I16"); variant += "_i16"; break; default: GGML_ABORT("Unsupported type for repeat shader"); } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_repeat, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; repeat_pipelines[key] = pipeline; return repeat_pipelines[key]; } webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { auto it = flash_attn_pipelines.find(context.key); if (it != flash_attn_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "flash_attn"; switch (context.key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; case GGML_TYPE_F16: defines.push_back("KV_F16"); break; case GGML_TYPE_Q4_0: defines.push_back("KV_Q4_0"); break; case GGML_TYPE_Q8_0: defines.push_back("KV_Q8_0"); break; default: GGML_ABORT("Unsupported KV type for flash attention shader"); } variant += std::string("_") + ggml_type_name(context.key.kv_type); if (context.key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } if (context.key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } if (context.key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } if (context.key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } if (context.key.has_mask && context.key.use_vec) { defines.push_back("BLK"); variant += "_blk"; } defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); variant += std::string("_hsqk") + 
std::to_string(context.key.head_dim_qk); defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.use_vec) { q_tile = 1; kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); } if (context.key.kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } } defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); uint32_t wg_size = 0; if (context.key.use_vec) { wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); } else { wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); const char * shader_src = context.key.use_vec ? 
wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); auto decisions = std::make_shared(); decisions->q_tile = q_tile; decisions->kv_tile = kv_tile; decisions->wg_size = wg_size; pipeline.context = decisions; flash_attn_pipelines[context.key] = pipeline; return flash_attn_pipelines[context.key]; } webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { auto it = flash_attn_blk_pipelines.find(context.key); if (it != flash_attn_blk_pipelines.end()) { return it->second; } ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); flash_attn_blk_pipelines[context.key] = pipeline; return flash_attn_blk_pipelines[context.key]; } webgpu_pipeline get_flash_attn_vec_reduce_pipeline( const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { auto it = flash_attn_vec_reduce_pipelines.find(context.key); if (it != flash_attn_vec_reduce_pipelines.end()) { return it->second; } ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); flash_attn_vec_reduce_pipelines[context.key] = pipeline; return flash_attn_vec_reduce_pipelines[context.key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_cpy_pipeline_key key = { .src_type = context.src0->type, .dst_type = context.dst->type, }; auto it = cpy_pipelines.find(key); if (it != cpy_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "cpy"; switch (key.src_type) { case GGML_TYPE_F32: defines.push_back("SRC_F32"); variant += "_f32"; 
break; case GGML_TYPE_F16: defines.push_back("SRC_F16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported src type for cpy shader"); } switch (key.dst_type) { case GGML_TYPE_F32: defines.push_back("DST_F32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("DST_F16"); variant += "_f16"; break; case GGML_TYPE_I32: defines.push_back("DST_I32"); variant += "_i32"; break; default: GGML_ABORT("Unsupported dst type for cpy shader"); } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_cpy, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; cpy_pipelines[key] = pipeline; return cpy_pipelines[key]; } webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_glu_pipeline_key key = { .glu_op = ggml_get_glu_op(context.dst), .type = context.dst->type, .split = (context.src1 != nullptr), }; auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "glu"; switch (key.glu_op) { case GGML_GLU_OP_REGLU: defines.push_back("OP_REGLU"); variant += "_reglu"; break; case GGML_GLU_OP_GEGLU: defines.push_back("OP_GEGLU"); variant += "_geglu"; break; case GGML_GLU_OP_SWIGLU: defines.push_back("OP_SWIGLU"); variant += "_swiglu"; break; case GGML_GLU_OP_SWIGLU_OAI: defines.push_back("OP_SWIGLU_OAI"); variant += "_swiglu_oai"; break; case GGML_GLU_OP_GEGLU_ERF: defines.push_back("OP_GEGLU_ERF"); variant += "_geglu_erf"; break; case GGML_GLU_OP_GEGLU_QUICK: defines.push_back("OP_GEGLU_QUICK"); variant += "_geglu_quick"; break; default: GGML_ABORT("Unsupported GLU op"); } switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("TYPE_F16"); variant += "_f16"; 
break; default: GGML_ABORT("Unsupported type for GLU shader"); } if (key.split) { variant += "_split"; } else { defines.push_back("NO_SPLIT"); } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_glu, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; glu_pipelines[key] = pipeline; return glu_pipelines[key]; } webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rope_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace, .has_ff = (context.src2 != nullptr), }; auto it = rope_pipelines.find(key); if (it != rope_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "rope"; switch (key.type) { case GGML_TYPE_F32: defines.push_back("TYPE_F32"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("TYPE_F16"); variant += "_f16"; break; default: GGML_ABORT("Unsupported type for ROPE shader"); } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } if (key.has_ff) { defines.push_back("FF_FUNC"); variant += "_ff"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_rope, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; rope_pipelines[key] = pipeline; return rope_pipelines[key]; } webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_soft_max_pipeline_key key = { .mask_type = context.src1 ? 
context.src1->type : GGML_TYPE_F32, .has_mask = (context.src1 != nullptr), .has_sink = (context.src2 != nullptr), .inplace = context.inplace, }; auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { return it->second; } std::vector defines; std::string variant = "soft_max"; if (key.has_mask) { defines.push_back("HAS_MASK"); switch (key.mask_type) { case GGML_TYPE_F32: defines.push_back("MASK_F32"); variant += "_mask_f32"; break; case GGML_TYPE_F16: defines.push_back("MASK_F16"); variant += "_mask_f16"; break; default: GGML_ABORT("Unsupported type for SOFT_MAX shader"); } } if (key.has_sink) { defines.push_back("HAS_SINK"); variant += "_sink"; } if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_soft_max, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; soft_max_pipelines[key] = pipeline; return soft_max_pipelines[key]; } private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, std::string label) { wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code.c_str(); wgpu::ShaderModuleDescriptor shader_desc; shader_desc.nextInChain = &shader_source; wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); wgpu::ComputePipelineDescriptor pipeline_desc; pipeline_desc.label = label.c_str(); pipeline_desc.compute.module = shader_module; pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { const size_t limit_bytes 
= context.wg_mem_limit_bytes; const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; if (!context.key.kv_direct) { bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); } if (context.key.has_mask) { bytes_per_kv += q_tile; } bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP