From b3927f807dc868d6429a80b8da84797ed9fd46cb Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Tue, 10 Feb 2026 19:27:33 -0800 Subject: [PATCH] Basic JIT compilation for mul_mat, get_rows, and scale (#17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 450 ++++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 619 ++++++++---------- .../wgsl-shaders/common_decls.tmpl | 137 ++-- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 8 +- .../{get_rows.tmpl.wgsl => get_rows.wgsl} | 307 ++------- .../{mul_mat.tmpl.wgsl => mul_mat.wgsl} | 304 ++------- .../wgsl-shaders/mul_mat_decls.tmpl | 42 +- ...g_tile.tmpl.wgsl => mul_mat_reg_tile.wgsl} | 128 +--- ...tmpl.wgsl => mul_mat_subgroup_matrix.wgsl} | 152 +---- ...mul_mat_vec.tmpl.wgsl => mul_mat_vec.wgsl} | 134 +--- .../{scale.tmpl.wgsl => scale.wgsl} | 45 +- 11 files changed, 993 insertions(+), 1333 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl => get_rows.wgsl} (83%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl => mul_mat.wgsl} (84%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl => mul_mat_reg_tile.wgsl} (59%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl => mul_mat_subgroup_matrix.wgsl} (66%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_vec.tmpl.wgsl => mul_mat_vec.wgsl} (65%) rename ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl => scale.wgsl} (78%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 63f797f142..26bcf56b86 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -4,6 +4,7 @@ #include "ggml.h" #include "pre_wgsl.hpp" +#include #include #include #include @@ -18,6 +19,46 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide +// mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + +#define WEBGPU_MAX_WG_SIZE 288 +#define WEBGPU_MUL_MAT_WG_SIZE 256 + +// helper function for replacing {{PLACEHOLDERS}} +inline void ggml_webgpu_replace_placeholder(std::string & shader_code, + const std::string & key, + const std::string & value) { + std::string pattern = "{{" + key + "}}"; + size_t pos = 0; + while ((pos = shader_code.find(pattern, pos)) != std::string::npos) { + shader_code.replace(pos, pattern.length(), value); + pos += value.length(); + } +} + struct ggml_webgpu_processed_shader { std::string wgsl; std::string variant; @@ -178,7 +219,8 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct KV loads + // Avoids having to use bounds-checks and decreasing performance for direct + // KV loads while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } @@ -466,6 +508,22 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( return result; } +/** Scale **/ + +struct ggml_webgpu_scale_pipeline_key { + int inplace; + + bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_scale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -490,6 +548,34 @@ struct ggml_webgpu_binary_pipeline_key_hash { } }; +struct ggml_webgpu_scale_shader_lib_context { + ggml_webgpu_scale_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_scale_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_scale_shader_lib_context & context) { + std::vector defines; + std::string variant = "scale"; + + if (context.key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} + struct ggml_webgpu_binary_shader_lib_context { ggml_webgpu_binary_pipeline_key key; uint32_t max_wg_size; @@ -535,4 +621,366 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( result.decisions = decisions; return result; } + +/** get_rows */ + +struct ggml_webgpu_get_rows_pipeline_key { + ggml_type src_type; + int vectorized; + + bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const { + return src_type == other.src_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_get_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +struct ggml_webgpu_get_rows_shader_lib_context { + ggml_webgpu_get_rows_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_get_rows_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_get_rows_shader_lib_context & context) { + std::vector defines; + std::string variant = "get_rows"; + + // Determine src type string + const char * type_str = nullptr; + + // src type + const struct ggml_type_traits * type_traits = ggml_get_type_traits(context.key.src_type); + type_str = type_traits->type_name; + + switch (context.key.src_type) { + case GGML_TYPE_F32: + if (context.key.vectorized) { + defines.push_back("F32_VEC"); + defines.push_back("SRC_TYPE=vec4"); + defines.push_back("DST_TYPE=vec4"); + defines.push_back("BLOCK_SIZE=4u"); + } else { + defines.push_back("F32"); + defines.push_back("SRC_TYPE=f32"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + } + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("F16"); + defines.push_back("SRC_TYPE=f16"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("I32"); + defines.push_back("SRC_TYPE=i32"); + defines.push_back("DST_TYPE=i32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_i32"; + break; + default: + // convert name to upper case for other defines + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + // push back defines for quantized types + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + + // for q4_k and q5_k + defines.push_back(type_upper + "_SCALE_MIN"); + + // defines for i-quants + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + // add variant + variant += "_"; + variant += type_str; + + // add define for quantized src0 type + defines.push_back(std::string("SRC_TYPE=") + type_str); + defines.push_back("DST_TYPE=f32"); + break; + } + + // determine block_size for quantized types + if (context.key.src_type == GGML_TYPE_I32) { + defines.push_back("BLOCK_SIZE=1u"); + } else if ((context.key.src_type >= GGML_TYPE_Q4_0 && context.key.src_type <= GGML_TYPE_Q8_1) || + context.key.src_type == GGML_TYPE_IQ4_NL) { + // Non-K quants use 32 + defines.push_back("BLOCK_SIZE=32u"); + } else if (context.key.src_type >= GGML_TYPE_Q2_K) { + // K-quants and IQ variants all use 256 + defines.push_back("BLOCK_SIZE=256u"); + } + + // Vectorized suffix + if (context.key.vectorized) { + variant += "_vec"; + } + + defines.push_back("WORKGROUP_SIZE=" + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + + // Create decisions structure to store workgroup size + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + + return result; +} + +/** Matrix Multiplication **/ + +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int is_vec; + int use_subgroup_matrix; + int register_tile; + + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + is_vec == other.is_vec && use_subgroup_matrix == other.use_subgroup_matrix && + register_tile == other.register_tile; + } +}; + +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.is_vec); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); + ggml_webgpu_hash_combine(seed, key.register_tile); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_shader_lib_context { + ggml_webgpu_mul_mat_pipeline_key key; + + // For subgroup matrix paths + uint32_t max_subgroup_size; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; +}; + +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int is_vec; + int use_subgroup_matrix; + + // Add new fields for all the parameters + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_mul_mat_shader_lib_context & context) { + std::vector defines; + std::string variant = "mul_mat"; + + // Determine base variant name based on kernel type + if (context.key.is_vec) { + variant = "mul_mat_vec"; + } else if (context.key.use_subgroup_matrix) { + variant = "mul_mat_subgroup_matrix"; + } else if (context.key.register_tile) { + variant = "mul_mat_reg_tile"; + } + + // Determine src0/src1 type strings + const char * src0_type_str = nullptr; + + bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile; + + // src1 type + switch (context.key.src1_type) { + case GGML_TYPE_F32: + defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f32"); + defines.push_back(context.key.vectorized ? "DST_TYPE=vec4" : "DST_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4" : "SRC1_TYPE=f16"); + defines.push_back(context.key.vectorized ? "DST_TYPE=vec4" : "DST_TYPE=f32"); + break; + default: + break; + } + + // same for all types + defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4" : "SHMEM_TYPE=f16"); + + // src0 type + const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type); + src0_type_str = src0_type_traits->type_name; + + // for f32 and f16, account for vectorized src0 types + switch (context.key.src0_type) { + case GGML_TYPE_F32: + src0_type_str = context.key.vectorized ? "vec4" : "f32"; + defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4" : "SRC0_TYPE=f32"); + + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + variant += "_f32"; + break; + + case GGML_TYPE_F16: + src0_type_str = context.key.vectorized ? "vec4" : "f16"; + defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4" : "SRC0_TYPE=f16"); + + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + variant += "_f16"; + + break; + + default: + // convert name to upper case for other defines + std::string type_upper = src0_type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + // push back defines for quantized types + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + + defines.push_back(type_upper); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + // for q4_k and q5_k + defines.push_back(type_upper + "_SCALE_MIN"); + + // defines for i-quants + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + // add variant + variant += "_"; + variant += src0_type_str; + + // add define for non-fast path quantized src0 type-- overwritten if fast + // path + defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); + + break; + } + + // Add VEC/SCALAR defines + if (is_fast_path) { + // if quantized type and using fast path, need to use f16 instead of the + // quantized type + if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) { + src0_type_str = "f16"; + defines.push_back(std::string("SRC0_TYPE=") + src0_type_str); + } + + // all fast paths need VEC vs SCALAR + defines.push_back(context.key.vectorized ? "VEC" : "SCALAR"); + // add vec_size define too + defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u"); + + // reg_tile and subgroup_matrix need these extra defines + if (!context.key.is_vec) { + defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR"); + } + } + + // Append src1 type + variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16"); + + // printf("DEBUG: After appending src1 type: variant='%s'\n", + // variant.c_str()); + + // Vectorized suffix + if (context.key.vectorized) { + variant += "_vec"; + } + + // Add defines for TILE_M and TILE_N + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + + // Add subgroup matrix defines if using subgroup_matrix + if (context.key.use_subgroup_matrix) { + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); + defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); + defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); + defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); + defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); + } + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + + auto decisions = std::make_shared(); + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K; + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; + decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; + decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; + decisions->is_vec = context.key.is_vec; + decisions->use_subgroup_matrix = context.key.use_subgroup_matrix; + result.decisions = decisions; + + return result; +} + #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 32e120266a..745d2a5e88 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -69,21 +69,24 @@ /* Constants */ -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. +// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to +// implementations so this can be removed. #define WEBGPU_MAX_WG_SIZE 288 #define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +// Maximum number of in-flight submissions per-thread, to avoid exhausting the +// parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable default +// For operations which process a row in parallel, this seems like a reasonable +// default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 // Matrix multiplication parameters @@ -106,13 +109,15 @@ // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size +// Must be multiple of 4 to work with vectorized paths, and must divide +// mul_mat_vec wg size #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 #define WEBGPU_MUL_MAT_VEC_TILE_K 256 /* End Constants */ -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to +// their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. @@ -358,9 +363,12 @@ struct webgpu_context_struct { webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::unordered_map + mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized std::unordered_map flash_attn_pipelines; @@ -372,10 +380,13 @@ struct webgpu_context_struct { std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet std::unordered_map - set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized + set_rows_pipelines; + std::unordered_map + get_rows_pipelines; // src_type, vectorized - std::map> cpy_pipelines; // src_type, dst_type + std::map> cpy_pipelines; // src_type, dst_type std::unordered_map binary_pipelines; @@ -383,7 +394,11 @@ struct webgpu_context_struct { std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace + + std::unordered_map + scale_pipelines; // inplace std::map>> soft_max_pipelines; // mask_type, has_sink, inplace std::unordered_map unary_pipelines; @@ -495,8 +510,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, - // inflight_max may be 0, meaning that we must wait on all futures. + // If we have too many in-flight submissions, wait on the oldest one first. If + // there are many threads, inflight_max may be 0, meaning that we must wait on + // all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; uint32_t inflight_threads = ctx->inflight_threads; uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); @@ -681,7 +697,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); } - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + // If there are SET_ROWS operations in this submission, copy their error + // buffers to the host. if (set_rows_error_bufs) { encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, set_rows_error_bufs->host_buf.GetSize()); @@ -971,7 +988,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - // For set rows specifically, we need to check if src and idx are empty tensors. + // For set rows specifically, we need to check if src and idx are empty + // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { return std::nullopt; } @@ -1054,25 +1072,68 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, error_bufs); } +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = wg_size; + return constants; +} + static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of dst - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) - }; + uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; + + // Create pipeline key + ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = (int) vectorized }; + + // Get or create pipeline + webgpu_pipeline pipeline; + auto it = ctx->get_rows_pipelines.find(key); + if (it != ctx->get_rows_pipelines.end()) { + pipeline = it->second; + } else { + // JIT compile the shader + const char * shader_src = wgsl_get_rows; + + // Build shader context + ggml_webgpu_get_rows_shader_lib_context shader_lib_ctx = { + .key = key, + .max_wg_size = WEBGPU_MAX_WG_SIZE, + }; + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_get_rows_shader(ctx->p, shader_src, shader_lib_ctx); + + std::vector constants = ggml_webgpu_wg_size_entry(shader_lib_ctx.max_wg_size); + pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), + processed.variant.c_str(), constants); + pipeline.context = processed.decisions; + ctx->get_rows_pipelines.emplace(key, pipeline); + } + + auto decisions = static_cast(pipeline.context.get()); + + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; std::vector entries = { { .binding = 0, @@ -1089,10 +1150,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); - uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1100,45 +1159,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[0], // number of rows in result (M, transposed) - (uint32_t) dst->ne[1], // number of columns in result (N) - (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 - (uint32_t) src0->ne[2], // batch size in dimension 2 - (uint32_t) src0->ne[3], // batch size in dimension 3 - (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 - (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 - }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - - webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - - uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); - uint32_t wg_y = 1; + // Determine if this is a mat-vec operation + bool is_vec = (dst->ne[1] == 1); + // Determine if we should use fast path bool use_fast = false; switch (src1->type) { case GGML_TYPE_F16: @@ -1159,43 +1183,158 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, break; } + int vectorized = 0; if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - if (dst->ne[1] == 1) { + vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0); + if (is_vec) { // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - } else { - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t wg_m; - uint32_t wg_n; -#ifndef __EMSCRIPTEN__ - if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * - ctx->global_ctx->capabilities.sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * - ctx->global_ctx->capabilities.sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); - } else { -#endif - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = CEIL_DIV(dst->ne[0], tile_m_s); - wg_n = CEIL_DIV(dst->ne[1], tile_n_s); -#ifndef __EMSCRIPTEN__ - } -#endif - - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + vectorized = vectorized && (src0->type < 2); } } + + // Create pipeline key + bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = src0->type, + .src1_type = src1->type, + .vectorized = use_fast ? vectorized : 0, + .is_vec = (use_fast && is_vec) ? 1 : 0, + .use_subgroup_matrix = + (use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0, + .register_tile = + (use_fast && !is_vec && !supports_subgroup_matrix) ? 1 : 0 }; + + // Build shader context + ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key, + .max_subgroup_size = + ctx->global_ctx->capabilities.max_subgroup_size, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k }; + + // Get or create pipeline + webgpu_pipeline pipeline; + const char * shader_src = nullptr; + + auto it = ctx->mul_mat_pipelines.find(key); + if (it != ctx->mul_mat_pipelines.end()) { + pipeline = it->second; + } else { + // Select appropriate shader source based on key + if (!use_fast) { + // Use precompiled quantized shaders (mul_mat.tmpl.wgsl) + // These are the fallback for quantized types not supported by fast + // paths + shader_src = wgsl_mul_mat; + } else { + // Use JIT-compiled shader + if (is_vec) { + shader_src = wgsl_mul_mat_vec; + } else if (key.use_subgroup_matrix) { + shader_src = wgsl_mul_mat_subgroup_matrix; + } else { + shader_src = wgsl_mul_mat_reg_tile; + } + } + + if (shader_src) { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx); + + std::vector constants; + if (shader_lib_ctx.key.is_vec) { + auto decisions = static_cast(processed.decisions.get()); + constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast(decisions->wg_size) }); + constants.push_back({ nullptr, "TILE_K", static_cast(decisions->tile_k) }); + constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast(decisions->outputs_per_wg) }); + } else if (shader_lib_ctx.key.register_tile) { + auto decisions = static_cast(processed.decisions.get()); + constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast(decisions->wg_size_m) }); + constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast(decisions->wg_size_n) }); + constants.push_back({ nullptr, "TILE_K", static_cast(decisions->tile_k) }); + } + // printf("DEBUG: Creating pipeline with variant='%s', " + // "constants.size()=%zu\n", + // processed.variant.c_str(), constants.size()); + + pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), + processed.variant.c_str(), constants); + pipeline.context = processed.decisions; + ctx->mul_mat_pipelines.emplace(key, pipeline); + } + } + + auto decisions = static_cast(pipeline.context.get()); + + // Build params + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) src0->ne[0], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], + (uint32_t) (src1->ne[2] / src0->ne[2]), + (uint32_t) (src1->ne[3] / src0->ne[3]) + }; + + // Build bind group entries + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + if (decisions->is_vec) { + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + } else if (use_fast) { + // Fast-path tiled/subgroup calculations + uint32_t wg_m, wg_n; + if (decisions->use_subgroup_matrix) { + uint32_t wg_m_sg_tile = + decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = + decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + } else { + uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m; + uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } else { + // Non-fast-path quantized shaders (Q2_K, Q4_K, etc.) + // Use the value from decisions instead of hardcoded constant + wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size); + wg_y = 1; + } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } @@ -1671,6 +1810,29 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); + ggml_webgpu_scale_pipeline_key key = { .inplace = inplace }; + + ggml_webgpu_scale_shader_lib_context shader_lib_ctx = { + .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; + + webgpu_pipeline pipeline; + // TODO: remove guard once pipeline caches are per-thread + auto it = ctx->scale_pipelines.find(key); + if (it != ctx->scale_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_scale_shader(ctx->p, wgsl_scale, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->scale_pipelines.emplace(key, pipeline); + } + + auto * decisions = static_cast(pipeline.context.get()); + + // params unchanged std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1688,12 +1850,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, *(uint32_t *) &dst->op_params[1] // bias }; + // bindgroups unchanged std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), @@ -1702,8 +1866,7 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, - entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -2233,7 +2396,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe size_t offset, size_t size) { if (size == 0) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + WEBGPU_LOG_DEBUG( + "ggml_backend_webgpu_buffer_memset_tensor: size is zero, " + "nothing to do."); return; } @@ -2310,7 +2475,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t final_size = size; if (size % 4 != 0) { - // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + // If size is not a multiple of 4, we need to round it up to the next + // multiple of 4 final_size = size + (4 - (size % 4)); } @@ -2364,7 +2530,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with + // .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -2401,7 +2568,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } -// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +// maxBufferSize might be larger, but you can't bind more than +// maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_webgpu_device_context * dev_ctx = static_cast(buft->device->context); @@ -2487,14 +2655,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// Workgroup size is a common constant -static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = wg_size; - return constants; -} - static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -2509,207 +2669,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - // Q4/Q5/Q8 classic quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - - // K-quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - - // IQ quantizations (2-, 3-, 4-bit variants) - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - - // 1-bit and 4-bit IQ variants - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - - std::string proc_mul_mat_f32_f32; - std::string proc_mul_mat_f32_f32_vec; - std::string proc_mul_mat_f16_f32; - std::string proc_mul_mat_f16_f32_vec; - std::string proc_mul_mat_f16_f16; - std::string proc_mul_mat_f16_f16_vec; - std::string proc_mul_mat_q4_0_f32; - std::string proc_mul_mat_q4_0_f32_vec; - - std::vector mul_mat_constants; -#ifndef __EMSCRIPTEN__ - if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { - std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = - std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); - sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); - sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); - sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - proc_mul_mat_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - proc_mul_mat_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - proc_mul_mat_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - proc_mul_mat_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - proc_mul_mat_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - } else { -#endif - mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - std::map reg_repls; - reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); - reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); -#ifndef __EMSCRIPTEN__ - } -#endif - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); - - std::vector mul_mat_vec_constants(3); - mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; - mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; - mul_mat_vec_constants[1].key = "TILE_K"; - mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; - mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; - mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); -} - -static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); -} - static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2816,15 +2775,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } -static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); -} - static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -3023,13 +2973,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_scale_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -3071,11 +3018,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, - /* .is_host = */ NULL, // defaults to false + /* .alloc_buffer = */ + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 389c97bb51..9a5b18ebc0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -1,5 +1,4 @@ -#decl(BYTE_HELPERS) - +#ifdef BYTE_HELPERS fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 { fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#endif -#enddecl(BYTE_HELPERS) - -#decl(Q4_0_T) +#ifdef Q4_0_T struct q4_0 { d: f16, qs: array }; -#enddecl(Q4_0_T) +#endif -#decl(Q4_1_T) +#ifdef Q4_1_T struct q4_1 { d: f16, m: f16, qs: array }; -#enddecl(Q4_1_T) +#endif -#decl(Q5_0_T) +#ifdef Q5_0_T struct q5_0 { d: f16, qh: array, qs: array }; -#enddecl(Q5_0_T) +#endif -#decl(Q5_1_T) +#ifdef Q5_1_T struct q5_1 { d: f16, m: f16, qh: u32, qs: array }; -#enddecl(Q5_1_T) +#endif -#decl(Q8_0_T) +#ifdef Q8_0_T struct q8_0 { d: f16, qs: array }; -#enddecl(Q8_0_T) +#endif -#decl(Q8_1_T) +#ifdef Q8_1_T struct q8_1 { d: f16, m: f16, qs: array }; -#enddecl(Q8_1_T) +#endif -#decl(Q2_K_T) -struct q2_k { +#ifdef Q2_K_T +struct q2_K { scales: array, qs: array, d: f16, dmin: f16 }; -#enddecl(Q2_K_T) +#endif -#decl(Q3_K_T) -struct q3_k { +#ifdef Q3_K_T +struct q3_K { hmask: array, qs: array, scales: array, d: f16 }; -#enddecl(Q3_K_T) - -#decl(Q45_K_SCALE_MIN) +#endif +#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { if (is < 4) { let sc_byte = get_byte(scales[is / 4], is % 4); @@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array) -> vec2 { return vec2(f32(sc), f32(m)); } } - -#enddecl(Q45_K_SCALE_MIN) - -#decl(Q4_K_T) -struct q4_k { +#endif +#ifdef Q4_K_T +struct q4_K { d: f16, dmin: f16, scales: array, qs: array }; -#enddecl(Q4_K_T) +#endif -#decl(Q5_K_T) -struct q5_k { +#ifdef Q5_K_T +struct q5_K { d: f16, dmin: f16, scales: array, qh: array, qs: array }; -#enddecl(Q5_K_T) +#endif -#decl(Q6_K_T) -struct q6_k { +#ifdef Q6_K_T +struct q6_K { ql: array, qh: array, scales: array, d: f16 }; -#enddecl(Q6_K_T) +#endif -#decl(IQ2_XXS_T) +#ifdef IQ2_XXS_T struct iq2_xxs { d: f16, qs: array }; -#enddecl(IQ2_XXS_T) +#endif -#decl(IQ2_XS_T) +#ifdef IQ2_XS_T struct iq2_xs { d: f16, qs: array, scales: array }; -#enddecl(IQ2_XS_T) +#endif -#decl(IQ2_S_T) +#ifdef IQ2_S_T struct iq2_s { d: f16, qs: array, qh: array, scales: array }; -#enddecl(IQ2_S_T) +#endif -#decl(IQ3_XSS_T) +#ifdef IQ3_XXS_T struct iq3_xxs { d: f16, qs: array }; -#enddecl(IQ3_XSS_T) +#endif -#decl(IQ3_S_T) +#ifdef IQ3_S_T struct iq3_s { d: f16, qs: array, @@ -161,41 +156,41 @@ struct iq3_s { signs: array, scales: array }; -#enddecl(IQ3_S_T) +#endif -#decl(IQ1_S_T) +#ifdef IQ1_S_T struct iq1_s { d: f16, qs: array, qh: array }; -#enddecl(IQ1_S_T) +#endif -#decl(IQ1_M_T) +#ifdef IQ1_M_T struct iq1_m { qs: array, qh: array, scales: array }; -#enddecl(IQ1_M_T) +#endif -#decl(IQ4_NL_T) +#ifdef IQ4_NL_T struct iq4_nl { d: f16, qs: array, }; -#enddecl(IQ4_NL_T) +#endif -#decl(IQ4_XS_T) +#ifdef IQ4_XS_T struct iq4_xs { d: f16, scales_h: f16, scales_l: u32, qs: array }; -#enddecl(IQ4_XS_T) +#endif -#decl(IQ23_TABLES) +#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES) const kmask_iq2xs : array = array( 0x08040201u, // 1, 2, 4, 8 0x80402010u // 16, 32, 64, 128 @@ -211,9 +206,9 @@ const ksigns_iq2xs: array = array( 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc ); -#enddecl(IQ23_TABLES) +#endif -#decl(IQ2_XXS_GRID) +#ifdef IQ2_XXS_GRID const iq2xxs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, @@ -280,9 +275,9 @@ const iq2xxs_grid = array( 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 ); -#enddecl(IQ2_XXS_GRID) +#endif -#decl(IQ2_XS_GRID) +#ifdef IQ2_XS_GRID const iq2xs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -413,9 +408,9 @@ const iq2xs_grid = array( 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_XS_GRID) +#endif -#decl(IQ2_S_GRID) +#ifdef IQ2_S_GRID const iq2s_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -674,10 +669,9 @@ const iq2s_grid = array( 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_S_GRID) - -#decl(IQ3_XSS_GRID) +#endif +#ifdef IQ3_XXS_GRID const iq3xxs_grid = array( 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -712,10 +706,9 @@ const iq3xxs_grid = array( 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 ); -#enddecl(IQ3_XSS_GRID) - -#decl(IQ3_S_GRID) +#endif +#ifdef IQ3_S_GRID const iq3s_grid = array( 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, @@ -782,9 +775,9 @@ const iq3s_grid = array( 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 ); -#enddecl(IQ3_S_GRID) +#endif -#decl(IQ1_GRID) +#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID) const IQ1_DELTA: f32 = 0.125; @@ -919,12 +912,12 @@ const iq1_grid = array( 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 ); -#enddecl(IQ1_GRID) +#endif -#decl(IQ4_GRID) +#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID) const kvalues_iq4nl = array( -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 ); -#enddecl(IQ4_GRID) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d61df5bb9e..2ce30edb90 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -56,7 +56,9 @@ def expand_includes(shader, input_dir): return include_pattern.sub(replacer, shader) -def write_shader(shader_name, shader_code, output_dir, outfile): +def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): + shader_code = expand_includes(shader_code, input_dir) + if output_dir: wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") with open(wgsl_filename, "w", encoding="utf-8") as f_out: @@ -74,7 +76,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): try: variants = ast.literal_eval(extract_block(text, "VARIANTS")) except ValueError: - write_shader(shader_base_name, text, output_dir, outfile) + write_shader(shader_base_name, text, output_dir, outfile, input_dir) else: try: decls_map = parse_decls(extract_block(text, "DECLS")) @@ -123,7 +125,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile) + write_shader(output_name, final_shader, output_dir, outfile, input_dir) def main(): diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl similarity index 83% rename from ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index f80ce1fc55..8e0f401e23 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,222 +1,31 @@ -#define(VARIANTS) +enable f16; +#include "common_decls.tmpl" -[ - { - "SHADER_SUFFIX": "f32_vec", - "REPLS": { - "TYPE" : "vec4", - "DST_TYPE": "vec4", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F32_VEC"] - }, - { - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F32"] - }, - { - "REPLS": { - "TYPE" : "f16", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] - }, - { - "REPLS": { - "TYPE" : "i32", - "DST_TYPE": "i32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["I32"] - }, - { - "REPLS": { - "TYPE" : "q4_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "TYPE" : "q4_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "TYPE" : "q5_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "TYPE" : "q5_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "TYPE" : "q8_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "TYPE" : "q2_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "TYPE" : "q3_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "TYPE" : "q4_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "TYPE" : "q5_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "TYPE" : "q6_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "TYPE" : "iq2_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "TYPE" : "iq2_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "TYPE": "iq2_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "TYPE": "iq3_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "TYPE": "iq3_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "TYPE": "iq1_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "TYPE": "iq1_m", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "TYPE": "iq4_nl", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "TYPE": "iq4_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(F32_VEC) +#ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; } -#enddecl(F32_VEC) +#endif -#decl(F32) +#ifdef F32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(F32) +#endif -#decl(F16) +#ifdef F16 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = f32(src[src_base + offset]); } -#enddecl(F16) +#endif -#decl(I32) +#ifdef I32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(I32) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_0 = src[src_base + offset]; let d = f32(block_q4_0.d); @@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_1 = src[src_base + offset]; let d = f32(block_q4_1.d); @@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_0 = src[src_base + offset]; let d = f32(block_q5_0.d); @@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(Q5_0) - -#decl(Q5_1) +#ifdef Q5_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_1 = src[src_base + offset]; let d = f32(block_q5_1.d); @@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q8_0 = src[src_base + offset]; let d = f32(block_q8_0.d); @@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q8_0) +#endif -#decl(Q2_K) +#ifdef Q2_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q2_K) +#endif -#decl(Q3_K) +#ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q3_K) +#endif -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_K) +#endif -#decl(Q5_K) +#ifdef Q5_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_K) +#endif -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { sc_b_idx += 8; } } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XXS) +#endif -#decl(IQ2_XS) +#ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XS) +#endif -#decl(IQ2_S) +#ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_XSS) +#endif -#decl(IQ3_S) +#ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_S) - -#decl(IQ1_M) +#ifdef IQ1_M fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i++; } } -#enddecl(IQ4_NL) +#endif -#decl(IQ4_XS) +#ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i += 16; } } -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif @group(0) @binding(0) -var src: array<{{TYPE}}>; +var src: array; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array<{{DST_TYPE}}>; +var dst: array; struct Params { offset_src: u32, // in elements @@ -866,9 +662,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; - for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { + for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl similarity index 84% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 0f8e6e5ac3..6aba47317c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,195 +1,24 @@ -#define(VARIANTS) +enable f16; -[ - { - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q8_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q2_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q3_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q6_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_m", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_nl", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] +#include "common_decls.tmpl" -#end(VARIANTS) +#ifdef FLOAT +const BLOCK_SIZE = 1u; -#define(DECLS) +#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) +const BLOCK_SIZE = 32u; -#decl(FLOAT) +#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) +const BLOCK_SIZE = 256u; +#endif + +#ifdef FLOAT fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); } -#enddecl(FLOAT) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_0 = src0[src0_idx_base + offset]; let d = f32(block_q4_0.d); @@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_1 = src0[src0_idx_base + offset]; let d = f32(block_q4_1.d); @@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_0 = src0[src0_idx_base + offset]; let d = f32(block_q5_0.d); @@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_0) +#endif -#decl(Q5_1) +#ifdef Q5_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_1 = src0[src0_idx_base + offset]; let d = f32(block_q5_1.d); @@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_0 = src0[src0_idx_base + offset]; let d = f32(block_q8_0.d); @@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_0) +#endif -#decl(Q8_1) +#ifdef Q8_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_1 = src0[src0_idx_base + offset]; let d = f32(block_q8_1.d); @@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_1) +#endif -#decl(Q2_K) +#ifdef Q2_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q2_K) - -#decl(Q3_K) +#ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q3_K) - -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q4_K) - -#decl(Q5_K) +#ifdef Q5_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q5_K) - -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) +#ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XS) - -#decl(IQ2_S) +#ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ3_XSS) - -#decl(IQ3_S) +#ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_S) -#decl(IQ1_M) +#ifdef IQ1_M fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ4_NL) - -#decl(IQ4_XS) +#ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } - -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, // in elements/blocks @@ -864,8 +672,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; var sum = 0.0; - for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { + for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 109ff8d615..6827995bf7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -1,58 +1,53 @@ -#decl(SHMEM_VEC) +#ifdef SHMEM_VEC fn store_shmem(val: vec4, idx: u32) { shmem[idx] = val.x; shmem[idx + 1] = val.y; shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#enddecl(SHMEM_VEC) +#endif -#decl(SHMEM_SCALAR) +#ifdef SHMEM_SCALAR fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#enddecl(SHMEM_SCALAR) - -#decl(INIT_SRC0_SHMEM_FLOAT) +#endif +#ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], + SRC0_TYPE(0.0), + src0[src0_idx/VEC_SIZE], global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } +#endif -#enddecl(INIT_SRC0_SHMEM_FLOAT) - -#decl(INIT_SRC1_SHMEM) - +#ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_n = offset_n + tile_n; let global_k = k_outer + tile_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } +#endif -#enddecl(INIT_SRC1_SHMEM) - -#decl(INIT_SRC0_SHMEM_Q4_0) - +#ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; @@ -93,5 +88,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } - -#enddecl(INIT_SRC0_SHMEM_Q4_0) +#endif \ No newline at end of file diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl similarity index 59% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 6b1dd26cd9..a53a950a7f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,115 +1,19 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] +enable f16; -#end(VARIANTS) +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" -#define(DECLS) - -#decl(VEC) +#ifdef VEC fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { return f32(acc[tm][tn]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -enable f16; +#endif struct MulMatParams { offset_src0: u32, @@ -130,14 +34,13 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS - fn get_local_n(thread_id: u32) -> u32 { return thread_id / WORKGROUP_SIZE_M; } @@ -145,10 +48,6 @@ fn get_local_m(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_M; } -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = {{WEBGPU_TILE_M}}u; -const TILE_N = {{WEBGPU_TILE_N}}u; - override WORKGROUP_SIZE_M: u32; override WORKGROUP_SIZE_N: u32; override TILE_K: u32; @@ -233,15 +132,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var tn = 0u; tn < TILE_N; tn++) { let global_col = output_col_base + tn; if (global_col < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { let global_row = output_row_base + tm; if (global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); } } } } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl similarity index 66% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 47c8ce36ab..64529e03cd 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -1,100 +1,12 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; -#end(VARIANTS) +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" -#define(DECLS) - -#decl(VEC) +#ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( f32(shmem[shmem_idx]), @@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { f32(shmem[shmem_idx + 3]) ); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = f32(shmem[shmem_idx]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -diagnostic(off, chromium.subgroup_matrix_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; +#endif struct MulMatParams { offset_src0: u32, @@ -138,36 +42,19 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS - -// Note: These are string interpolated at build time, cannot use override constants due to limitations in -// current Dawn version type definitions/matrix load requirements for constant memory sizes. -const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; -const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; -// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the -// runtime subgroup size is smaller. -const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; - -const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; - -const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; -const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; -const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; - -const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; -const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; - -const TILE_K = {{WEBGPU_TILE_K}}u; - const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let local_row = idx % WG_TILE_STRIDE; let local_col = idx / WG_TILE_STRIDE; @@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_col < params.n && global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - store_dst(idx, dst_idx/{{VEC_SIZE}}); + store_dst(idx, dst_idx/VEC_SIZE); } } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl similarity index 65% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index ffbb640328..ac08ff31de 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,84 +1,11 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] - } -] -#end(VARIANTS) +enable f16; -#define(DECLS) +#include "common_decls.tmpl" -#decl(VEC) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +#ifdef VEC +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); } fn store_val(group_base: u32) -> vec4 { @@ -87,33 +14,31 @@ fn store_val(group_base: u32) -> vec4 { partial_sums[group_base + THREADS_PER_OUTPUT * 2], partial_sums[group_base + THREADS_PER_OUTPUT * 3]); } -#enddecl(VEC) +#endif -#decl(SCALAR) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { +#ifdef SCALAR +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } fn store_val(group_base: u32) -> f32 { return partial_sums[group_base]; } -#enddecl(SCALAR) - -#decl(MUL_ACC_FLOAT) +#endif +#ifdef MUL_ACC_FLOAT fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; - for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; + for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { + let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; + let b = shared_vector[i / VEC_SIZE]; local_sum += inner_dot(a, b); } return local_sum; } +#endif -#enddecl(MUL_ACC_FLOAT) - -#decl(MUL_ACC_Q4_0) +#ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; const NQ = 16u; // number of weights per thread @@ -145,15 +70,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } return local_sum; } - -#enddecl(MUL_ACC_Q4_0) - -#end(DECLS) - -#define(SHADER) -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, @@ -174,9 +91,10 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; @@ -186,7 +104,7 @@ override OUTPUTS_PER_WG: u32; override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; // Shared memory for collaborative loading and reduction -var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile +var shared_vector: array; // Cache vector tile var partial_sums: array; // For reduction @compute @workgroup_size(WORKGROUP_SIZE) @@ -232,8 +150,8 @@ fn main( let tile_size = min(TILE_K, params.k - k_tile); // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { - shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) { + shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; } workgroupBarrier(); @@ -260,8 +178,8 @@ fn main( } // Store back to global memory - if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { - dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { + dst[dst_idx / VEC_SIZE] = store_val(group_base); } } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl similarity index 78% rename from ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 040e80dfea..3b70a876d7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -1,21 +1,11 @@ -#define(VARIANTS) +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; -[ - { - "SHADER_NAME": "scale_f32", - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "scale_f32_inplace", - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) +fn store_scale(val: f32, offset: u32) { + src[offset] = val; +} +#else @group(0) @binding(1) var dst: array; @@ -25,20 +15,7 @@ var params: Params; fn store_scale(val: f32, offset: u32) { dst[offset] = val; } -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -@group(0) @binding(1) -var params: Params; - -fn store_scale(val: f32, offset: u32) { - src[offset] = val; -} -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) +#endif struct Params { offset_src: u32, @@ -65,10 +42,7 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; @@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3) { store_scale(src[i_src] * params.scale + params.bias, i_dst); } -#end(SHADER)