From ae6baf4714068153eab5d737760dec743857c508 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 11 Feb 2026 12:32:43 -0800 Subject: [PATCH] flashattention and matrix multiplication moved to new format --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 1286 +++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 200 +-- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 13 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 13 +- 4 files changed, 893 insertions(+), 619 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f358903476..02c11e60a8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -59,12 +59,20 @@ template inline void ggml_webgpu_hash_combine(size_t & seed, const struct ggml_webgpu_shader_lib_context { ggml_tensor * src0; ggml_tensor * src1; + ggml_tensor * src2; + ggml_tensor * src3; + ggml_tensor * src4; ggml_tensor * dst; uint32_t max_wg_size; - size_t wg_mem_limit_bytes = 0; - bool inplace = 0; - bool overlap = 0; + size_t wg_mem_limit_bytes = 0; + bool inplace = 0; + bool overlap = 0; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t max_subgroup_size = 0; }; struct webgpu_pipeline { @@ -212,6 +220,167 @@ struct ggml_webgpu_unary_pipeline_key_hash { } }; +/** FlashAttention */ + +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_type kv_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { + return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_shader_lib_context { + ggml_webgpu_flash_attn_pipeline_key key; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; +}; + +struct ggml_webgpu_flash_attn_shader_decisions { + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + +// This is exposed because it's necessary in supports_op +inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, + uint32_t kv_tile, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); + size_t f16_elems = 0; + size_t f32_elems = 0; + f16_elems += q_tile * head_dim_qk; // q_shmem + if (!kv_direct) { + f16_elems += kv_tile * max_head_dim; // kv_shmem + } + f16_elems += q_tile * head_dim_v; // o_shmem + if (has_mask) { + f16_elems += q_tile * kv_tile; // mask_shmem + } + f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile; // row_max_shmem + f32_elems += q_tile; // exp_sum_shmem + 
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; +} + +/** Matrix Multiplication **/ + +struct ggml_webgpu_legacy_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + + bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; + } +}; + +struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_vec_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + + bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_vec_shader_decisions { + uint32_t wg_size; + uint32_t tile_k; + uint32_t outputs_per_wg; + uint32_t vec_size; +}; + +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int use_subgroup_matrix; + + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_subgroup_matrix == other.use_subgroup_matrix; + } +}; + +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int use_subgroup_matrix; + + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -222,15 +391,25 @@ class ggml_webgpu_shader_lib { std::unordered_map argsort_merge_pipelines; // key is order std::unordered_map cumsum_pipelines; // key is fixed, no variants yet std::unordered_map - get_rows_pipelines; // src_type, vectorized + get_rows_pipelines; // src_type, vectorized std::unordered_map - unary_pipelines; // type/op/inplace + unary_pipelines; // type/op/inplace std::unordered_map - scale_pipelines; // inplace + scale_pipelines; // inplace std::unordered_map - pad_pipelines; // circular/non-circular + pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap + std::unordered_map + flash_attn_pipelines; + std::unordered_map + mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) + std::unordered_map + mul_mat_vec_pipelines; 
// fast mat-vec (n==1) + std::unordered_map + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) std::unordered_map set_rows_pipelines; @@ -385,9 +564,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; + const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, + .src_type = context.src0->type, .vectorized = (int) vectorized, }; @@ -431,33 +610,34 @@ class ggml_webgpu_shader_lib { defines.push_back("BLOCK_SIZE=1u"); variant += "_i32"; break; - default: { - std::string type_upper = type_str; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + default: + { + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - defines.push_back(type_upper); - defines.push_back(type_upper + "_SCALE_MIN"); - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); - variant += "_"; - variant += type_str; + variant += "_"; + variant += type_str; - defines.push_back(std::string("SRC_TYPE=") + type_str); - defines.push_back("DST_TYPE=f32"); + defines.push_back(std::string("SRC_TYPE=") + type_str); + defines.push_back("DST_TYPE=f32"); - if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { - defines.push_back("BLOCK_SIZE=32u"); - } else if (key.src_type >= GGML_TYPE_Q2_K) { - defines.push_back("BLOCK_SIZE=256u"); - } else { - defines.push_back("BLOCK_SIZE=1u"); + if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL) { + defines.push_back("BLOCK_SIZE=32u"); + } else if (key.src_type >= GGML_TYPE_Q2_K) { + defines.push_back("BLOCK_SIZE=256u"); + } else { + defines.push_back("BLOCK_SIZE=1u"); + } + break; } - break; - } } if (key.vectorized) { @@ -466,9 +646,9 @@ class ggml_webgpu_shader_lib { defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_get_rows, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_get_rows, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; get_rows_pipelines[key] = pipeline; @@ -493,9 +673,9 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_scale, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_scale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; scale_pipelines[key] = pipeline; 
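Every getter this patch adds to the shader library follows the same look-up-or-build shape: construct a plain key struct describing the shader variant, probe an unordered_map keyed by that struct (with a matching hash functor built on ggml_webgpu_hash_combine), and only run the preprocessor and compile a pipeline on a miss. A minimal self-contained sketch of that shape, using hypothetical example_* names rather than any type from this patch, and a local hash_combine that mirrors the header's ggml_webgpu_hash_combine helper:

#include <cstddef>
#include <functional>
#include <unordered_map>

// Boost-style combiner, standing in for ggml_webgpu_hash_combine.
template <typename T> static void hash_combine(size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct example_key {
    int  type;  // e.g. a ggml_type value
    bool vectorized;

    bool operator==(const example_key & other) const {
        return type == other.type && vectorized == other.vectorized;
    }
};

struct example_key_hash {
    size_t operator()(const example_key & key) const {
        size_t seed = 0;
        hash_combine(seed, key.type);
        hash_combine(seed, key.vectorized);
        return seed;
    }
};

struct example_pipeline {};  // stands in for webgpu_pipeline

std::unordered_map<example_key, example_pipeline, example_key_hash> example_pipelines;

example_pipeline get_example_pipeline(const example_key & key) {
    auto it = example_pipelines.find(key);
    if (it != example_pipelines.end()) {
        return it->second;  // cache hit: reuse the previously compiled variant
    }
    // cache miss: build the define list, preprocess the WGSL, compile, cache
    example_pipeline pipeline{};  // stands in for preprocess + create_pipeline
    example_pipelines[key] = pipeline;
    return example_pipelines[key];
}

Because the key captures everything that changes the generated WGSL (types, vectorization, subgroup support, head dimensions, ...), each distinct shader specialization is compiled exactly once per device.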
@@ -520,23 +700,319 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
 
-        auto processed = preprocessor.preprocess(wgsl_pad, defines);
-        auto decisions = std::make_shared();
-        decisions->wg_size = context.max_wg_size;
+        auto processed     = preprocessor.preprocess(wgsl_pad, defines);
+        auto decisions     = std::make_shared();
+        decisions->wg_size = context.max_wg_size;
         webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
         pipeline.context = decisions;
-        pad_pipelines[key] = pipeline;
+        pad_pipelines[key]       = pipeline;
         return pad_pipelines[key];
     }
 
+    webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_vec_pipeline_key key = {
+            .src0_type = context.src0->type,
+            .src1_type = context.src1->type,
+            // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+        };
+
+        auto it = mul_mat_vec_pipelines.find(key);
+        if (it != mul_mat_vec_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat_vec";
+
+        // src1 type (vector)
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+        }
+
+        // src0 type (matrix row)
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
+                defines.push_back("MUL_ACC_FLOAT");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("MUL_ACC_FLOAT");
+                break;
+            default:
+                {
+                    // Quantized types: use helpers but accumulate in f16
+                    const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+                    std::string                     src0_name   = src0_traits->type_name;
+                    std::string                     type_upper  = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+
+                    // For the fast path we always dequantize to f16 inside the shader
+                    defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                    break;
+                }
+        }
+
+        // dst type
+        defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
+
+        // vec/scalar controls
+        if (key.vectorized) {
+            defines.push_back("VEC");
+            defines.push_back("VEC_SIZE=4u");
+        } else {
+            defines.push_back("SCALAR");
+            defines.push_back("VEC_SIZE=1u");
+        }
+
+        uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_TILE_K;
+        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
+        defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
+
+        auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
+        auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
+        decisions->wg_size        = wg_size;
+        decisions->tile_k         = tile_k;
+        decisions->outputs_per_wg = outputs_per_wg;
+        decisions->vec_size       = key.vectorized ? 4 : 1;
+
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context           = decisions;
+        mul_mat_vec_pipelines[key] = pipeline;
+        return mul_mat_vec_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_pipeline_key key = {
+            .src0_type  = context.src0->type,
+            .src1_type  = context.src1->type,
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+            .use_subgroup_matrix = context.supports_subgroup_matrix
+        };
+
+        auto it = mul_mat_fast_pipelines.find(key);
+        if (it != mul_mat_fast_pipelines.end()) {
+            return it->second;
+        }
+
+        const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
+        std::vector<std::string> defines;
+        std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
+
+        // src1 type
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
+                defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
+                defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
+        }
+
+        // Shared memory type
+        defines.push_back(key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
+
+        // src0 type
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+                    defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
+                    defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+
+                    // Use f16 inside the shader for quantized types
+                    defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
+        defines.push_back(key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
+        defines.push_back(key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
+
+        // Tiles
+        defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
+        defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
+        defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
+
+        // Subgroup matrix specifics
+        if (key.use_subgroup_matrix) {
+            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
+            defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
+            defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
+            defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
+        }
+
+        // variant suffix for src1 type
+        variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
+        if (key.vectorized) {
+            variant += "_vectorized";
+        }
+
+        if (!key.use_subgroup_matrix) {
+            defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
+            defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
+        }
+
+        auto processed = preprocessor.preprocess(shader_src, defines);
+
+        auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
+        decisions->tile_k              = WEBGPU_MUL_MAT_TILE_K;
+        decisions->tile_m              = WEBGPU_MUL_MAT_TILE_M;
+        decisions->tile_n              = WEBGPU_MUL_MAT_TILE_N;
+        decisions->use_subgroup_matrix = key.use_subgroup_matrix;
+        if (key.use_subgroup_matrix) {
+            decisions->subgroup_m        = WEBGPU_MUL_MAT_SUBGROUP_M;
+            decisions->subgroup_n        = WEBGPU_MUL_MAT_SUBGROUP_N;
+            decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
+            decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
+            decisions->wg_size           = context.max_subgroup_size;
+        } else {
+            decisions->wg_size_m       = WEBGPU_MUL_MAT_WG_SIZE_M;
+            decisions->wg_size_n       = WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->wg_size         = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+        }
+
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context            = decisions;
+        mul_mat_fast_pipelines[key] = pipeline;
+        return mul_mat_fast_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
+                                                        .src1_type = context.src1->type };
+
+        auto it = mul_mat_legacy_pipelines.find(key);
+        if (it != mul_mat_legacy_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat";
+
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
+        }
+
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_TYPE=f32");
+                defines.push_back("FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_TYPE=f16");
+                defines.push_back("FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    // quantized types
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back(std::string("SRC0_TYPE=") + src0_name);
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back(type_upper + "_T");
+                    defines.push_back(type_upper);
+                    defines.push_back(type_upper + "_SCALE_MIN");
+                    defines.push_back(type_upper + "_TABLES");
+                    defines.push_back(type_upper + "_GRID");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
+
+        auto decisions     = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
+        decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context              = decisions;
+        mul_mat_legacy_pipelines[key] = pipeline;
+        return mul_mat_legacy_pipelines[key];
+    }
+
     webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const bool is_unary = context.dst->op == GGML_OP_UNARY;
-        const int  op       = is_unary ? 
(int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = { + .type = context.dst->type, + .op = op, + .is_unary = is_unary, + .inplace = context.inplace, }; auto it = unary_pipelines.find(key); @@ -545,8 +1021,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : - ggml_op_name((ggml_op) key.op); + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); defines.push_back(variant); switch (key.type) { @@ -569,9 +1045,9 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_unary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; unary_pipelines[key] = pipeline; @@ -620,15 +1096,114 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_binary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; binary_pipelines[key] = pipeline; return binary_pipelines[key]; } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + + bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % context.sg_mat_n == 0); + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = context.src1->type, + .head_dim_qk = (uint32_t) context.src0->ne[0], + .head_dim_v = (uint32_t) context.src2->ne[0], + .kv_direct = kv_direct, + .has_mask = has_mask, + .has_sinks = has_sinks, + .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, + }; + + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + 
defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = + std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, + context.wg_mem_limit_bytes, context.max_subgroup_size }), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (key.kv_direct) { + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, @@ -648,6 +1223,25 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } + + static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); + } + if (context.key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; + } }; // helper function for replacing {{PLACEHOLDERS}} @@ -668,413 +1262,179 @@ struct ggml_webgpu_processed_shader { std::shared_ptr decisions; }; -/** FlashAttention */ - -struct ggml_webgpu_flash_attn_pipeline_key { - ggml_type kv_type; - uint32_t head_dim_qk; - uint32_t head_dim_v; - bool kv_direct; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - - bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == 
other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; - } -}; - -struct ggml_webgpu_flash_attn_pipeline_key_hash { - size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - return seed; - } -}; - -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; -}; - -struct ggml_webgpu_flash_attn_shader_decisions { - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; -}; - -// This is exposed because it's necessary in supports_op -inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, - uint32_t kv_tile, - uint32_t head_dim_qk, - uint32_t head_dim_v, - bool has_mask, - bool kv_direct) { - const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); - size_t f16_elems = 0; - size_t f32_elems = 0; - f16_elems += q_tile * head_dim_qk; // q_shmem - if (!kv_direct) { - f16_elems += kv_tile * max_head_dim; // kv_shmem - } - f16_elems += q_tile * head_dim_v; // o_shmem - if (has_mask) { - f16_elems += q_tile * kv_tile; // mask_shmem - } - f16_elems += q_tile * kv_tile; // inter_shmem - f32_elems += q_tile; // row_max_shmem - f32_elems += q_tile; // exp_sum_shmem - return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; -} - -static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; -} - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn"; - - switch (context.key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(context.key.kv_type); - - if (context.key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (context.key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (context.key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - - if 
(context.key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - // For now these are not part of the variant name - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - // Add chosen Q/KV tile sizes - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct - // KV loads - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } - } - - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - // workgroup size - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} /** Matrix Multiplication **/ -struct ggml_webgpu_mul_mat_pipeline_key { - ggml_type src0_type; - ggml_type src1_type; - int vectorized; - int is_vec; - int use_subgroup_matrix; - int register_tile; - - bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - is_vec == other.is_vec && use_subgroup_matrix == other.use_subgroup_matrix && - register_tile == other.register_tile; - } -}; - -struct ggml_webgpu_mul_mat_pipeline_key_hash { - size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.src0_type); - ggml_webgpu_hash_combine(seed, key.src1_type); - ggml_webgpu_hash_combine(seed, key.vectorized); - ggml_webgpu_hash_combine(seed, key.is_vec); - ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); - ggml_webgpu_hash_combine(seed, key.register_tile); - return seed; - } -}; - -struct ggml_webgpu_mul_mat_shader_lib_context { - ggml_webgpu_mul_mat_pipeline_key key; - - // For subgroup matrix paths - uint32_t max_subgroup_size; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; -}; - -struct ggml_webgpu_mul_mat_shader_decisions { - uint32_t tile_k; - uint32_t wg_size_m; - uint32_t wg_size_n; - uint32_t wg_size; - uint32_t outputs_per_wg; - int is_vec; - int use_subgroup_matrix; - - // Add new fields for all the parameters - uint32_t tile_m; - uint32_t tile_n; - - // Subgroup matrix parameters - uint32_t subgroup_m; - uint32_t subgroup_n; - uint32_t subgroup_matrix_m; - uint32_t subgroup_matrix_n; - - uint32_t mul_mat_wg_size; -}; - -inline 
ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_mul_mat_shader_lib_context & context) { - std::vector defines; - std::string variant = "mul_mat"; - - // Determine base variant name based on kernel type - if (context.key.is_vec) { - variant = "mul_mat_vec"; - } else if (context.key.use_subgroup_matrix) { - variant = "mul_mat_subgroup_matrix"; - } else if (context.key.register_tile) { - variant = "mul_mat_reg_tile"; - } - - // Determine src0/src1 type strings - const char * src0_type_str = nullptr; - - bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile; - - // src1 type - switch (context.key.src1_type) { - case GGML_TYPE_F32: - defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f32"); - defines.push_back(context.key.vectorized ? "DST_TYPE=vec4" : "DST_TYPE=f32"); - break; - case GGML_TYPE_F16: - defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f16"); - defines.push_back(context.key.vectorized ? "DST_TYPE=vec4" : "DST_TYPE=f32"); - break; - default: - break; - } - - // same for all types - defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4" : "SHMEM_TYPE=f16"); - - // src0 type - const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type); - src0_type_str = src0_type_traits->type_name; - - // for f32 and f16, account for vectorized src0 types - switch (context.key.src0_type) { - case GGML_TYPE_F32: - src0_type_str = context.key.vectorized ? "vec4" : "f32"; - defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4" : "SRC0_TYPE=f32"); - - defines.push_back("FLOAT"); - defines.push_back("MUL_ACC_FLOAT"); - defines.push_back("INIT_SRC0_SHMEM_FLOAT"); - defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - variant += "_f32"; - break; - - case GGML_TYPE_F16: - src0_type_str = context.key.vectorized ? "vec4" : "f16"; - defines.push_back(context.key.vectorized ? 
"SRC0_TYPE=vec4" : "SRC0_TYPE=f16"); - - defines.push_back("FLOAT"); - defines.push_back("MUL_ACC_FLOAT"); - defines.push_back("INIT_SRC0_SHMEM_FLOAT"); - defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - variant += "_f16"; - - break; - - default: - // convert name to upper case for other defines - std::string type_upper = src0_type_str; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - // push back defines for quantized types - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - - defines.push_back(type_upper); - defines.push_back("MUL_ACC_" + type_upper); - defines.push_back("INIT_SRC0_SHMEM_" + type_upper); - defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - // for q4_k and q5_k - defines.push_back(type_upper + "_SCALE_MIN"); - - // defines for i-quants - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); - - // add variant - variant += "_"; - variant += src0_type_str; - - // add define for non-fast path quantized src0 type-- overwritten if fast - // path - defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); - - break; - } - - // Add VEC/SCALAR defines - if (is_fast_path) { - // if quantized type and using fast path, need to use f16 instead of the - // quantized type - if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) { - src0_type_str = "f16"; - defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); - } - - // all fast paths need VEC vs SCALAR - defines.push_back(context.key.vectorized ? "VEC" : "SCALAR"); - // add vec_size define too - defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u"); - - // reg_tile and subgroup_matrix need these extra defines - if (!context.key.is_vec) { - defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR"); - } - } - - // Append src1 type - variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? 
"f32" : "f16"); - - // printf("DEBUG: After appending src1 type: variant='%s'\n", - // variant.c_str()); - - // Vectorized suffix - if (context.key.vectorized) { - variant += "_vec"; - } - - // Add defines for TILE_M and TILE_N - defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); - defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - - // Add subgroup matrix defines if using subgroup_matrix - if (context.key.use_subgroup_matrix) { - defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); - defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); - defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); - defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); - defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); - defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); - defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); - defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); - } - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - - auto decisions = std::make_shared(); - decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; - decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; - decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K; - decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; - decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; - decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; - decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; - decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; - decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; - decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; - decisions->is_vec = context.key.is_vec; - decisions->use_subgroup_matrix = context.key.use_subgroup_matrix; - result.decisions = decisions; - - return result; -} +//inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader( +// pre_wgsl::Preprocessor & preprocessor, +// const char * shader_src, +// const ggml_webgpu_mul_mat_shader_lib_context & context) { +// std::vector defines; +// std::string variant = "mul_mat"; +// +// // Determine base variant name based on kernel type +// if (context.key.is_vec) { +// variant = "mul_mat_vec"; +// } else if (context.key.use_subgroup_matrix) { +// variant = "mul_mat_subgroup_matrix"; +// } else if (context.key.register_tile) { +// variant = "mul_mat_reg_tile"; +// } +// +// // Determine src0/src1 type strings +// const char * src0_type_str = nullptr; +// +// bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile; +// +// // src1 type +// switch (context.key.src1_type) { +// case GGML_TYPE_F32: +// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f32"); +// defines.push_back(context.key.vectorized ? "DST_TYPE=vec4" : "DST_TYPE=f32"); +// break; +// case GGML_TYPE_F16: +// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f16"); +// defines.push_back(context.key.vectorized ? 
"DST_TYPE=vec4" : "DST_TYPE=f32"); +// break; +// default: +// break; +// } +// +// // same for all types +// defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4" : "SHMEM_TYPE=f16"); +// +// // src0 type +// const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type); +// src0_type_str = src0_type_traits->type_name; +// +// // for f32 and f16, account for vectorized src0 types +// switch (context.key.src0_type) { +// case GGML_TYPE_F32: +// src0_type_str = context.key.vectorized ? "vec4" : "f32"; +// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4" : "SRC0_TYPE=f32"); +// +// defines.push_back("FLOAT"); +// defines.push_back("MUL_ACC_FLOAT"); +// defines.push_back("INIT_SRC0_SHMEM_FLOAT"); +// defines.push_back("INIT_SRC1_SHMEM_FLOAT"); +// +// variant += "_f32"; +// break; +// +// case GGML_TYPE_F16: +// src0_type_str = context.key.vectorized ? "vec4" : "f16"; +// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4" : "SRC0_TYPE=f16"); +// +// defines.push_back("FLOAT"); +// defines.push_back("MUL_ACC_FLOAT"); +// defines.push_back("INIT_SRC0_SHMEM_FLOAT"); +// defines.push_back("INIT_SRC1_SHMEM_FLOAT"); +// +// variant += "_f16"; +// +// break; +// +// default: +// // convert name to upper case for other defines +// std::string type_upper = src0_type_str; +// std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); +// +// // push back defines for quantized types +// defines.push_back("BYTE_HELPERS"); +// defines.push_back(type_upper + "_T"); +// +// defines.push_back(type_upper); +// defines.push_back("MUL_ACC_" + type_upper); +// defines.push_back("INIT_SRC0_SHMEM_" + type_upper); +// defines.push_back("INIT_SRC1_SHMEM_FLOAT"); +// +// // for q4_k and q5_k +// defines.push_back(type_upper + "_SCALE_MIN"); +// +// // defines for i-quants +// defines.push_back(type_upper + "_TABLES"); +// defines.push_back(type_upper + "_GRID"); +// +// // add variant +// variant += "_"; +// variant += src0_type_str; +// +// // add define for non-fast path quantized src0 type-- overwritten if fast +// // path +// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); +// +// break; +// } +// +// // Add VEC/SCALAR defines +// if (is_fast_path) { +// // if quantized type and using fast path, need to use f16 instead of the +// // quantized type +// if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) { +// src0_type_str = "f16"; +// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); +// } +// +// // all fast paths need VEC vs SCALAR +// defines.push_back(context.key.vectorized ? "VEC" : "SCALAR"); +// // add vec_size define too +// defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u"); +// +// // reg_tile and subgroup_matrix need these extra defines +// if (!context.key.is_vec) { +// defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR"); +// } +// } +// +// // Append src1 type +// variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? 
"f32" : "f16"); +// +// // printf("DEBUG: After appending src1 type: variant='%s'\n", +// // variant.c_str()); +// +// // Vectorized suffix +// if (context.key.vectorized) { +// variant += "_vec"; +// } +// +// // Add defines for TILE_M and TILE_N +// defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); +// defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); +// +// // Add subgroup matrix defines if using subgroup_matrix +// if (context.key.use_subgroup_matrix) { +// defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); +// defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); +// defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); +// defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); +// defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); +// defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); +// defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); +// defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); +// defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); +// } +// +// ggml_webgpu_processed_shader result; +// result.wgsl = preprocessor.preprocess(shader_src, defines); +// result.variant = variant; +// +// auto decisions = std::make_shared(); +// decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; +// decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; +// decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K; +// decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; +// decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; +// decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; +// decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; +// decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; +// decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; +// decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; +// decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; +// decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; +// decisions->is_vec = context.key.is_vec; +// decisions->use_subgroup_matrix = context.key.use_subgroup_matrix; +// result.decisions = decisions; +// +// return result; +//} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 745d78bc78..377972e1f6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -362,12 +362,9 @@ struct webgpu_context_struct { std::unordered_map - mul_mat_pipelines; // src0_type, src1_type, vectorized + mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map - flash_attn_pipelines; + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized std::map> cpy_pipelines; // src_type, dst_type @@ -891,9 +888,7 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + .src0 = src, .dst = dst, .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); @@ -957,7 +952,10 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, } ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .src1 = idx, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + .src0 = src, + .src1 = idx, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -1032,13 +1030,13 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * idx, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = WEBGPU_MAX_WG_SIZE, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), @@ -1108,88 +1106,29 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - int vectorized = 0; - if (use_fast) { - vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0); - if (is_vec) { - // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - } - } - - // Create pipeline key - bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = src0->type, - .src1_type = src1->type, - .vectorized = use_fast ? vectorized : 0, - .is_vec = (use_fast && is_vec) ? 1 : 0, - .use_subgroup_matrix = - (use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0, - .register_tile = - (use_fast && !is_vec && !supports_subgroup_matrix) ? 
1 : 0 };
-
-    // Build shader context
-    ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key,
-                                                              .max_subgroup_size =
-                                                                  ctx->global_ctx->capabilities.max_subgroup_size,
-                                                              .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
-                                                              .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
-                                                              .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k };
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0                     = src0,
+        .src1                     = src1,
+        .dst                      = dst,
+        .max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
+        .sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size,
+    };
 
     // Get or create pipeline
     webgpu_pipeline pipeline;
-    const char *    shader_src = nullptr;
-    auto            it         = ctx->mul_mat_pipelines.find(key);
-    if (it != ctx->mul_mat_pipelines.end()) {
-        pipeline = it->second;
+    if (use_fast && is_vec) {
+        pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
+    } else if (use_fast) {
+        pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
     } else {
-        // Select appropriate shader source based on key
-        if (!use_fast) {
-            // Use precompiled quantized shaders (mul_mat.tmpl.wgsl)
-            // These are the fallback for quantized types not supported by fast
-            // paths
-            shader_src = wgsl_mul_mat;
-        } else {
-            // Use JIT-compiled shader
-            if (is_vec) {
-                shader_src = wgsl_mul_mat_vec;
-            } else if (key.use_subgroup_matrix) {
-                shader_src = wgsl_mul_mat_subgroup_matrix;
-            } else {
-                shader_src = wgsl_mul_mat_reg_tile;
-            }
-        }
-
-        if (shader_src) {
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx);
-
-            std::vector<wgpu::ConstantEntry> constants;
-            if (shader_lib_ctx.key.is_vec) {
-                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
-                constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(decisions->wg_size) });
-                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
-                constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast<double>(decisions->outputs_per_wg) });
-            } else if (shader_lib_ctx.key.register_tile) {
-                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
-                constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast<double>(decisions->wg_size_m) });
-                constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast<double>(decisions->wg_size_n) });
-                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
-            }
-            // printf("DEBUG: Creating pipeline with variant='%s', "
-            //        "constants.size()=%zu\n",
-            //        processed.variant.c_str(), constants.size());
-
-            pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
-                                                   processed.variant.c_str(), constants);
-            pipeline.context = processed.decisions;
-            ctx->mul_mat_pipelines.emplace(key, pipeline);
-        }
+        pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
     }
 
-    auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
-
     // Build params
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -1230,13 +1169,17 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     uint32_t wg_x = 1;
     uint32_t wg_y = 1;
 
-    if (decisions->is_vec) {
+    if (use_fast && is_vec) {
+        auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
+
         uint32_t batches       = dst->ne[2] * dst->ne[3];
         uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
         uint32_t total_wg      = output_groups * batches;
         wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
         wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
     } else if (use_fast) {
+        auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
+
         // Fast-path tiled/subgroup calculations
         uint32_t wg_m, wg_n;
         if (decisions->use_subgroup_matrix) {
@@ -1253,11 +1196,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
             wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
         }
         wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
-    } else {
-        // Non-fast-path quantized shaders (Q2_K, Q4_K, etc.)
-        // Use the value from decisions instead of hardcoded constant
-        wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size);
-        wg_y = 1;
+    } else {  // legacy
+        auto     decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
+        uint32_t wg_size   = decisions->wg_size;
+        wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
+        wg_y = 1;
     }
 
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
@@ -1347,40 +1290,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                           .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                           .size   = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
-                     (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
-
-    ggml_webgpu_flash_attn_pipeline_key key = {
-        .kv_type            = K->type,
-        .head_dim_qk        = (uint32_t) Q->ne[0],
-        .head_dim_v         = (uint32_t) V->ne[0],
-        .kv_direct          = kv_direct,
-        .has_mask           = static_cast<bool>(has_mask),
-        .has_sinks          = static_cast<bool>(has_sinks),
-        .uses_logit_softcap = logit_softcap != 0.0f,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0               = Q,
+        .src1               = K,
+        .src2               = V,
+        .src3               = mask,
+        .src4               = sinks,
+        .dst                = dst,
+        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+        .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,
     };
 
-    webgpu_pipeline pipeline;
-    auto            it = ctx->flash_attn_pipelines.find(key);
-    if (it != ctx->flash_attn_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
-            .key                = key,
-            .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
-            .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
-            .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
-            .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-            .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size
-        };
-
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->flash_attn_pipelines.emplace(key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
 
@@ -1402,7 +1327,7 @@ static webgpu_command 
@@ -1402,7 +1327,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
         .inplace = inplace,
     };
 
-    webgpu_pipeline pipeline  = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast(pipeline.context.get());
 
@@ -1483,7 +1408,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
         .overlap = flags.overlap,
     };
 
-    webgpu_pipeline pipeline  = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
+    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast(pipeline.context.get());
 
@@ -1860,19 +1785,18 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
 }
 
 static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool is_top_k  = dst->op == GGML_OP_TOP_K;
+    bool is_top_k = dst->op == GGML_OP_TOP_K;
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src,
-        .src1        = nullptr,
-        .dst         = dst,
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
         .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
     };
 
     webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
-    auto * argsort_decisions =
-        static_cast(argsort_pipeline.context.get());
+    auto * argsort_decisions = static_cast(argsort_pipeline.context.get());
 
     webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
 
@@ -2034,14 +1958,14 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
     };
 
     ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0        = src,
-        .src1        = nullptr,
-        .dst         = dst,
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
 
-    uint32_t wg_x  = ggml_nrows(dst);
+    uint32_t wg_x = ggml_nrows(dst);
 
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index a53a950a7f..771e5cd1ee 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -34,7 +34,6 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
 @group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
 @group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
 @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns (transposed)
@@ -48,14 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
     return thread_id % WORKGROUP_SIZE_M;
 }
 
-override WORKGROUP_SIZE_M: u32;
-override WORKGROUP_SIZE_N: u32;
-override TILE_K: u32;
-
-override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
-override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
-override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
-
+const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
+const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
+const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
 
 var shmem: array;
 
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -142,4 +136,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         }
     }
 }
-
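Both WGSL files drop pipeline-overridable constants in favor of const values, which are presumably emitted into the shader text by the shader library's preprocessing step before compilation; derived quantities such as TILE_SRC0_SHMEM then fold at shader-compile time instead of pipeline-creation time. A minimal sketch of that kind of injection, assuming the constant names from the shader above (the helper name and signature are illustrative, not the backend's actual API):

    // Illustrative-only helper: prepend WGSL const declarations to a shader
    // string so downstream const expressions specialize at compile time.
    #include <cstdint>
    #include <string>

    std::string inject_consts(const std::string & wgsl, uint32_t wg_size_m, uint32_t wg_size_n, uint32_t tile_k) {
        std::string header;
        header += "const WORKGROUP_SIZE_M: u32 = " + std::to_string(wg_size_m) + ";\n";
        header += "const WORKGROUP_SIZE_N: u32 = " + std::to_string(wg_size_n) + ";\n";
        header += "const TILE_K: u32 = " + std::to_string(tile_k) + ";\n";
        // Consts such as TOTAL_WORKGROUP_SIZE and TILE_SRC0_SHMEM in the shader
        // body now evaluate against these injected values.
        return header + wgsl;
    }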
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
index ac08ff31de..bbce0de9a1 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -98,16 +98,13 @@ struct MulMatParams {
 
 @group(0) @binding(3) var<uniform> params: MulMatParams;
 
-override WORKGROUP_SIZE: u32;
-override TILE_K: u32;
-override OUTPUTS_PER_WG: u32;
-override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
+const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
 
 // Shared memory for collaborative loading and reduction
 var shared_vector: array; // Cache vector tile
-var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
+var<workgroup> partial_sums: array<f32, WG_SIZE>;  // For reduction
 
-@compute @workgroup_size(WORKGROUP_SIZE)
+@compute @workgroup_size(WG_SIZE)
 fn main(
     @builtin(local_invocation_id) local_id: vec3<u32>,
     @builtin(workgroup_id) wg_id: vec3<u32>,
@@ -150,7 +147,7 @@ fn main(
     let tile_size = min(TILE_K, params.k - k_tile);
 
     // Cooperatively load vector tile into shared memory (all threads)
-    for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) {
+    for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
         shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
     }
 
@@ -168,7 +165,7 @@ fn main(
     workgroupBarrier();
     let group_base = thread_group * THREADS_PER_OUTPUT;
     let thread_base = group_base + thread_in_group;
-    var offset = THREADS_PER_OUTPUT / 2;
+    var offset: u32 = THREADS_PER_OUTPUT / 2;
     while (offset > 0) {
         if (thread_in_group < offset) {
             partial_sums[thread_base] += partial_sums[thread_base + offset];
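The explicit : u32 on offset plausibly matters once THREADS_PER_OUTPUT becomes a const: if the injected constants are untyped (abstract-int), var offset = THREADS_PER_OUTPUT / 2; would concretize to i32, whereas the annotated form preserves the u32 arithmetic of the old override version. The loop itself is a standard halving tree reduction over each output's partial sums; a CPU model of one group's reduction (illustrative, not part of the patch):

    // CPU model of the shader's halving reduction over THREADS_PER_OUTPUT
    // partial sums; the WGSL version runs one lane per element and separates
    // rounds with workgroupBarrier().
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t threads_per_output = 8;  // example value of THREADS_PER_OUTPUT
        float partial_sums[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };

        // Each round halves the active range; lane i accumulates lane i + offset.
        for (uint32_t offset = threads_per_output / 2; offset > 0; offset /= 2) {
            for (uint32_t i = 0; i < offset; ++i) {  // models "if (thread_in_group < offset)"
                partial_sums[i] += partial_sums[i + offset];
            }
            // A workgroupBarrier() would sit here between rounds in WGSL.
        }
        printf("reduced dot-product partial: %f\n", partial_sums[0]);  // 36.0
        return 0;
    }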