diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 26bcf56b86..cf24659b23 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -3,6 +3,9 @@ /* Hunk: the header now includes ggml-wgsl-shaders.hpp itself; the same include is dropped from ggml-webgpu.cpp later in this patch. NOTE(review): the bare "#include" entries appear to have lost their targets in this copy of the patch. */ #include "ggml.h" #include "pre_wgsl.hpp" +#include "ggml-wgsl-shaders.hpp" + +#include #include #include @@ -47,6 +50,64 @@ #define WEBGPU_MAX_WG_SIZE 288 #define WEBGPU_MUL_MAT_WG_SIZE 256 /* Inputs passed to ggml_webgpu_shader_lib when a pipeline is requested: three tensor pointers plus a workgroup-size limit. get_sum_rows_pipeline below reads only max_wg_size. */ +struct ggml_webgpu_shader_lib_context { + ggml_tensor * src0; + ggml_tensor * src1; + ggml_tensor * dst; + + uint32_t max_wg_size; +}; + /* Compute pipeline handle together with its name and an optional context pointer (null by default). This definition is removed from ggml-webgpu.cpp by a later hunk of this patch. NOTE(review): the template argument of std::shared_ptr is missing in this copy. */ +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; + std::shared_ptr context = nullptr; +}; + /* Holds a device handle, a WGSL preprocessor and a map of already-built sum_rows pipelines; get_sum_rows_pipeline is its only public accessor in this patch. */ +class ggml_webgpu_shader_lib { + wgpu::Device device; + pre_wgsl::Preprocessor preprocessor; + + std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet + + public: /* Stores the given device handle for later pipeline creation. */ + ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } + /* Returns the pipeline stored under key 1 if there is one. Otherwise preprocesses wgsl_sum_rows with the define WG_SIZE=context.max_wg_size, builds a pipeline labelled "sum_rows", stores it under key 1 and returns it. NOTE(review): the key is always 1, so a later call with a different max_wg_size receives the pipeline built by the first call - confirm that max_wg_size cannot vary for a given instance. */ + webgpu_pipeline get_sum_rows_pipeline(ggml_webgpu_shader_lib_context & context) { + auto it = sum_rows_pipelines.find(1); + if (it != sum_rows_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); + sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); + return sum_rows_pipelines[1]; + } + + private: + /* Creates a shader module from the WGSL text, then a compute pipeline that uses it with entry point "main" and a null (auto) layout. The returned webgpu_pipeline carries the label as its name. */ + static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + std::string shader_code, + std::string label) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label.c_str(); + pipeline_desc.compute.module = 
shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + return { device.CreateComputePipeline(&pipeline_desc), label }; + } +}; + // helper function for replacing {{PLACEHOLDERS}} inline void ggml_webgpu_replace_placeholder(std::string & shader_code, const std::string & key, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 745d2a5e88..a0bc1b125c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ /* Hunk: ggml-wgsl-shaders.hpp is no longer included here; ggml-webgpu-shader-lib.hpp includes it after this patch. */ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "ggml-wgsl-shaders.hpp" #include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ @@ -23,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -268,12 +268,6 @@ struct webgpu_gpu_profile_buf_pool { }; #endif /* Hunk: webgpu_pipeline is deleted here; the same definition is added to ggml-webgpu-shader-lib.hpp by this patch. */ -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; - std::shared_ptr context = nullptr; -}; - struct webgpu_command { wgpu::CommandBuffer commands; std::vector params_bufs; @@ -358,6 +352,8 @@ struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context webgpu_global_context global_ctx; /* New member: assigned in initialize_webgpu_context (last hunk) and used by ggml_webgpu_sum_rows. */ + std::unique_ptr shader_lib; + pre_wgsl::Preprocessor p; webgpu_buf_pool param_buf_pool; @@ -377,7 +373,6 @@ struct webgpu_context_struct { std::unordered_map argsort_pipelines; // key is order (asc/desc) std::unordered_map argsort_merge_pipelines; // key is order (asc/desc) std::unordered_map cumsum_pipelines; // key is fixed, no variants yet /* The sum_rows map leaves the context struct; a map of the same name is a member of ggml_webgpu_shader_lib. */ - std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet std::unordered_map set_rows_pipelines; @@ -2215,22 +2210,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, - .max_wg_size = 
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, /* Replacement initializer: names src0, dst and max_wg_size; src1 is not named. */ + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; /* The inline find / preprocess / create / emplace sequence removed below is replaced by one call to get_sum_rows_pipeline. */ - webgpu_pipeline pipeline; - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2965,6 +2950,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; /* The shader library is created from the global context's device. NOTE(review): the template arguments of make_shared and make_unique are missing in this copy. */ + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);