diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 02c11e60a8..7e9f4eee62 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -731,11 +731,11 @@ class ggml_webgpu_shader_lib {
         // src1 type (vector)
         switch (context.src1->type) {
             case GGML_TYPE_F32:
-                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
+                defines.push_back("SRC1_INNER_TYPE=f32");
                 variant += "_f32";
                 break;
             case GGML_TYPE_F16:
-                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
+                defines.push_back("SRC1_INNER_TYPE=f16");
                 variant += "_f16";
                 break;
             default:
@@ -745,11 +745,11 @@ class ggml_webgpu_shader_lib {
         // src0 type (matrix row)
         switch (context.src0->type) {
             case GGML_TYPE_F32:
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
+                defines.push_back("SRC0_INNER_TYPE=f32");
                 defines.push_back("MUL_ACC_FLOAT");
                 break;
             case GGML_TYPE_F16:
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("SRC0_INNER_TYPE=f16");
                 defines.push_back("MUL_ACC_FLOAT");
                 break;
             default:
@@ -764,22 +764,13 @@ class ggml_webgpu_shader_lib {
                 defines.push_back("MUL_ACC_" + type_upper);
 
                 // For fast path we always dequantize from f16 inside the shader
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("SRC0_INNER_TYPE=f16");
                 break;
             }
         }
 
-        // dst type
-        defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
-
-        // vec/scalar controls
-        if (key.vectorized) {
-            defines.push_back("VEC");
-            defines.push_back("VEC_SIZE=4u");
-        } else {
-            defines.push_back("SCALAR");
-            defines.push_back("VEC_SIZE=1u");
-        }
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
 
         uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
         uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
@@ -824,27 +815,22 @@ class ggml_webgpu_shader_lib {
         // src1 type
         switch (context.src1->type) {
             case GGML_TYPE_F32:
-                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
-                defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
+                defines.push_back("SRC1_INNER_TYPE=f32");
                 break;
             case GGML_TYPE_F16:
-                defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
-                defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
+                defines.push_back("SRC1_INNER_TYPE=f16");
                 break;
             default:
                 GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
         }
 
-        // Shared memory type
-        defines.push_back(key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
-
         // src0 type
         const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
         const char * src0_name = src0_traits->type_name;
 
         switch (context.src0->type) {
             case GGML_TYPE_F32:
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
+                defines.push_back("SRC0_INNER_TYPE=f32");
                 defines.push_back("FLOAT");
                 defines.push_back("MUL_ACC_FLOAT");
                 defines.push_back("INIT_SRC0_SHMEM_FLOAT");
@@ -852,7 +838,7 @@ class ggml_webgpu_shader_lib {
                 variant += "_f32";
                 break;
             case GGML_TYPE_F16:
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("SRC0_INNER_TYPE=f16");
                 defines.push_back("FLOAT");
                 defines.push_back("MUL_ACC_FLOAT");
                 defines.push_back("INIT_SRC0_SHMEM_FLOAT");
@@ -870,7 +856,7 @@ class ggml_webgpu_shader_lib {
                 defines.push_back("INIT_SRC1_SHMEM_FLOAT");
 
                 // Use f16 inside the shader for quantized types
-                defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
+                defines.push_back("SRC0_INNER_TYPE=f16");
 
                 variant += std::string("_") + src0_name;
                 break;
@@ -879,8 +865,6 @@ class ggml_webgpu_shader_lib {
         // VEC/SCALAR controls
         defines.push_back(key.vectorized ? "VEC" : "SCALAR");
-        defines.push_back(key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
-        defines.push_back(key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
 
         // Tiles
         defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
         defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
@@ -1244,197 +1228,4 @@ class ggml_webgpu_shader_lib {
     }
 };
 
-// helper function for replacing {{PLACEHOLDERS}}
-inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
-                                            const std::string & key,
-                                            const std::string & value) {
-    std::string pattern = "{{" + key + "}}";
-    size_t pos = 0;
-    while ((pos = shader_code.find(pattern, pos)) != std::string::npos) {
-        shader_code.replace(pos, pattern.length(), value);
-        pos += value.length();
-    }
-}
-
-struct ggml_webgpu_processed_shader {
-    std::string wgsl;
-    std::string variant;
-    std::shared_ptr decisions;
-};
-
-/** Matrix Multiplication **/
-
-//inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader(
-//    pre_wgsl::Preprocessor &                       preprocessor,
-//    const char *                                   shader_src,
-//    const ggml_webgpu_mul_mat_shader_lib_context & context) {
-//    std::vector<std::string> defines;
-//    std::string variant = "mul_mat";
-//
-//    // Determine base variant name based on kernel type
-//    if (context.key.is_vec) {
-//        variant = "mul_mat_vec";
-//    } else if (context.key.use_subgroup_matrix) {
-//        variant = "mul_mat_subgroup_matrix";
-//    } else if (context.key.register_tile) {
-//        variant = "mul_mat_reg_tile";
-//    }
-//
-//    // Determine src0/src1 type strings
-//    const char * src0_type_str = nullptr;
-//
-//    bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile;
-//
-//    // src1 type
-//    switch (context.key.src1_type) {
-//        case GGML_TYPE_F32:
-//            defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
-//            defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
-//            break;
-//        case GGML_TYPE_F16:
-//            defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
-//            defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
-//            break;
-//        default:
-//            break;
-//    }
-//
-//    // same for all types
-//    defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
-//
-//    // src0 type
-//    const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type);
-//    src0_type_str = src0_type_traits->type_name;
-//
-//    // for f32 and f16, account for vectorized src0 types
-//    switch (context.key.src0_type) {
-//        case GGML_TYPE_F32:
-//            src0_type_str = context.key.vectorized ? "vec4<f32>" : "f32";
-//            defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
-//
-//            defines.push_back("FLOAT");
-//            defines.push_back("MUL_ACC_FLOAT");
-//            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
-//            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
-//
-//            variant += "_f32";
-//            break;
-//
-//        case GGML_TYPE_F16:
-//            src0_type_str = context.key.vectorized ? "vec4<f16>" : "f16";
-//            defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
-//
-//            defines.push_back("FLOAT");
-//            defines.push_back("MUL_ACC_FLOAT");
-//            defines.push_back("INIT_SRC0_SHMEM_FLOAT");
-//            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
-//
-//            variant += "_f16";
-//
-//            break;
-//
-//        default:
-//            // convert name to upper case for other defines
-//            std::string type_upper = src0_type_str;
-//            std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
-//
-//            // push back defines for quantized types
-//            defines.push_back("BYTE_HELPERS");
-//            defines.push_back(type_upper + "_T");
-//
-//            defines.push_back(type_upper);
-//            defines.push_back("MUL_ACC_" + type_upper);
-//            defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
-//            defines.push_back("INIT_SRC1_SHMEM_FLOAT");
-//
-//            // for q4_k and q5_k
-//            defines.push_back(type_upper + "_SCALE_MIN");
-//
-//            // defines for i-quants
-//            defines.push_back(type_upper + "_TABLES");
-//            defines.push_back(type_upper + "_GRID");
-//
-//            // add variant
-//            variant += "_";
-//            variant += src0_type_str;
-//
-//            // add define for non-fast path quantized src0 type-- overwritten if fast
-//            // path
-//            defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
-//
-//            break;
-//    }
-//
-//    // Add VEC/SCALAR defines
-//    if (is_fast_path) {
-//        // if quantized type and using fast path, need to use f16 instead of the
-//        // quantized type
-//        if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) {
-//            src0_type_str = "f16";
-//            defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
-//        }
-//
-//        // all fast paths need VEC vs SCALAR
-//        defines.push_back(context.key.vectorized ? "VEC" : "SCALAR");
-//        // add vec_size define too
-//        defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
-//
-//        // reg_tile and subgroup_matrix need these extra defines
-//        if (!context.key.is_vec) {
-//            defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
-//        }
-//    }
-//
-//    // Append src1 type
-//    variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16");
-//
-//    // printf("DEBUG: After appending src1 type: variant='%s'\n",
-//    // variant.c_str());
-//
-//    // Vectorized suffix
-//    if (context.key.vectorized) {
-//        variant += "_vec";
-//    }
-//
-//    // Add defines for TILE_M and TILE_N
-//    defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
-//    defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
-//
-//    // Add subgroup matrix defines if using subgroup_matrix
-//    if (context.key.use_subgroup_matrix) {
-//        defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
-//        defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
-//        defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
-//        defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
-//        defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
-//        defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
-//        defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
-//        defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
-//        defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
-//    }
-//
-//    ggml_webgpu_processed_shader result;
-//    result.wgsl = preprocessor.preprocess(shader_src, defines);
-//    result.variant = variant;
-//
-//    auto decisions = std::make_shared();
-//    decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
-//    decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
-//    decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K;
-//    decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
-//    decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
-//    decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-//    decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
-//    decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
-//    decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
-//    decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
-//    decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
-//    decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
-//    decisions->is_vec = context.key.is_vec;
-//    decisions->use_subgroup_matrix = context.key.use_subgroup_matrix;
-//    result.decisions = decisions;
-//
-//    return result;
-//}
-
 #endif // GGML_WEBGPU_SHADER_LIB_HPP
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 377972e1f6..d8782f2408 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -354,18 +354,9 @@ struct webgpu_context_struct {
     std::unique_ptr shader_lib;
 
-    pre_wgsl::Preprocessor p;
-
     webgpu_buf_pool param_buf_pool;
     webgpu_buf_pool set_rows_error_buf_pool;
 
-    std::unordered_map
-        mul_mat_pipelines;  // src0_type, src1_type, vectorized
-    std::map>>
-        mul_mat_vec_pipelines;  // src0_type, src1_type, vectorized
-
     std::map> cpy_pipelines;  // src_type, dst_type
 
     std::map rms_norm_pipelines;  // inplace
@@ -414,25 +405,6 @@ struct ggml_backend_webgpu_buffer_context {
 
 /* WebGPU object initializations */
 
-// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
-// the corresponding values provided in `repls`.
-static std::string ggml_webgpu_process_shader_repls(const char * src,
-                                                    const std::map<std::string, std::string> & repls) {
-    if (!src) {
-        return std::string();
-    }
-    std::string s = src;
-    for (const auto & kv : repls) {
-        std::string token = "{{" + kv.first + "}}";
-        size_t pos = 0;
-        while ((pos = s.find(token, pos)) != std::string::npos) {
-            s.replace(pos, token.length(), kv.second);
-            pos += kv.second.length();
-        }
-    }
-    return s;
-}
-
 static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
                                                    const char * shader_code,
                                                    const char * label,
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
index 6827995bf7..efa743197f 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -1,4 +1,10 @@
-#ifdef SHMEM_VEC
+#ifdef VEC
+#define VEC_SIZE 4
+#define SHMEM_TYPE vec4<f16>
+#define DST_TYPE vec4<f32>
+#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
+#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
+
 fn store_shmem(val: vec4<f16>, idx: u32) {
     shmem[idx] = val.x;
     shmem[idx + 1] = val.y;
@@ -7,7 +13,13 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
 }
 #endif
 
-#ifdef SHMEM_SCALAR
+#ifdef SCALAR
+#define VEC_SIZE 1
+#define SHMEM_TYPE f16
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
+
 fn store_shmem(val: f16, idx: u32) {
     shmem[idx] = val;
 }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
index bbce0de9a1..f9ea95e07b 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -4,6 +4,12 @@ enable f16;
 #include "common_decls.tmpl"
 
 #ifdef VEC
+
+#define VEC_SIZE 4
+#define DST_TYPE vec4<f32>
+#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
+#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
+
 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
     return f32(dot(SRC1_TYPE(src0_val), src1_val));
 }
@@ -17,6 +23,12 @@ fn store_val(group_base: u32) -> vec4<f32> {
 #endif
 
 #ifdef SCALAR
+
+#define VEC_SIZE 1
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
+
 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
     return f32(src0_val) * f32(src1_val);
 }
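
For reference, a minimal sketch (not part of the patch) of what the new defines resolve to after preprocessing, assuming a quantized src0, so the host passes SRC0_INNER_TYPE=f16 together with SRC1_INNER_TYPE=f32 and VEC; the vec4 element types are inferred from their scalar counterparts above:

    enable f16;

    // VEC path expansion: VEC_SIZE = 4, SRC0_TYPE = vec4<f16>,
    // SRC1_TYPE = vec4<f32>, DST_TYPE = vec4<f32>.
    fn inner_dot(src0_val: vec4<f16>, src1_val: vec4<f32>) -> f32 {
        // SRC1_TYPE(src0_val) widens src0 to vec4<f32> before the dot product.
        return f32(dot(vec4<f32>(src0_val), src1_val));
    }

On the SCALAR path the same templates collapse to SRC0_TYPE = f16, SRC1_TYPE = f32, DST_TYPE = f32 and VEC_SIZE = 1, so a single host-side VEC/SCALAR define now drives all per-shader type choices instead of the host pushing each outer type separately.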