Start work on all-encompassing shader library

This commit is contained in:
Reese Levine 2026-02-10 20:30:19 -08:00
parent b3927f807d
commit 75e66cb49d
2 changed files with 69 additions and 22 deletions

View File

@ -3,6 +3,9 @@
#include "ggml.h"
#include "pre_wgsl.hpp"
#include "ggml-wgsl-shaders.hpp"
#include <webgpu/webgpu_cpp.h>
#include <algorithm>
#include <memory>
@ -47,6 +50,64 @@
#define WEBGPU_MAX_WG_SIZE 288
#define WEBGPU_MUL_MAT_WG_SIZE 256
// Inputs handed to ggml_webgpu_shader_lib when a pipeline is requested.
//
// Every member has a default so that a default-constructed (or partially
// designated-initialized) context never carries indeterminate pointers or an
// indeterminate workgroup size. The struct remains an aggregate, so existing
// `{ .src0 = ..., .dst = ..., .max_wg_size = ... }` call sites are unaffected.
struct ggml_webgpu_shader_lib_context {
    ggml_tensor * src0 = nullptr;  // first source tensor of the op, if any
    ggml_tensor * src1 = nullptr;  // second source tensor of the op, if any
    ggml_tensor * dst  = nullptr;  // destination tensor of the op

    // Workgroup size forwarded to shaders as the WG_SIZE define
    // (see ggml_webgpu_shader_lib::get_sum_rows_pipeline).
    uint32_t max_wg_size = 0;
};
// A created compute pipeline bundled with the label it was created under.
// Kept as a plain aggregate: ggml_webgpu_create_pipeline brace-initializes it
// positionally as { pipeline, name }, so member order matters.
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
// Label passed to ggml_webgpu_create_pipeline (e.g. "sum_rows").
std::string name;
// Type-erased, shared extra data attached to the pipeline; empty unless set.
// NOTE(review): nothing visible in this header populates it — confirm the
// intended contents with the code that reads this field.
std::shared_ptr<void> context = nullptr;
};
// Builds compute pipelines from the embedded WGSL sources on demand and caches
// them, so each shader is preprocessed and compiled at most once per library
// instance.
class ggml_webgpu_shader_lib {
    wgpu::Device           device;        // device every pipeline is created on
    pre_wgsl::Preprocessor preprocessor;  // expands the WGSL sources with per-pipeline defines

    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;  // key is fixed, no variants yet

  public:
    // explicit: a device handle must not silently convert into a shader library.
    // The member is initialized directly instead of default-constructed and
    // then assigned.
    explicit ggml_webgpu_shader_lib(wgpu::Device device) : device(device) {}

    // Returns the sum_rows pipeline, creating and caching it on first use.
    //
    // context: only max_wg_size is read; it becomes the WG_SIZE define of the
    //          shader. It is taken by const reference because it is never
    //          modified.
    //
    // NOTE: there is a single cache entry, so once the pipeline exists the
    // context of later calls is ignored and the first pipeline is returned.
    webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
        constexpr int fixed_key = 1;  // single entry until shader variants are introduced

        auto it = sum_rows_pipelines.find(fixed_key);
        if (it == sum_rows_pipelines.end()) {
            std::vector<std::string> defines;
            defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));

            auto processed = preprocessor.preprocess(wgsl_sum_rows, defines);

            // emplace() hands back the iterator of the new entry, so the map is
            // not hashed a second (and third) time through operator[].
            it = sum_rows_pipelines
                     .emplace(fixed_key, ggml_webgpu_create_pipeline(device, processed, "sum_rows"))
                     .first;
        }
        return it->second;
    }

  private:
    // Compiles `shader_code` (WGSL) into a shader module and wraps it in a
    // compute pipeline labelled `label`. The returned webgpu_pipeline stores
    // the label as its name.
    //
    // Both strings are only read for the duration of the call, so they are
    // taken by const reference to avoid copying the whole shader source.
    static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &      device,
                                                       const std::string & shader_code,
                                                       const std::string & label) {
        wgpu::ShaderSourceWGSL shader_source;
        shader_source.code = shader_code.c_str();

        wgpu::ShaderModuleDescriptor shader_desc;
        shader_desc.nextInChain = &shader_source;

        wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

        wgpu::ComputePipelineDescriptor pipeline_desc;
        pipeline_desc.label              = label.c_str();
        pipeline_desc.compute.module     = shader_module;
        pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
        pipeline_desc.layout             = nullptr;  // nullptr means auto layout
        return { device.CreateComputePipeline(&pipeline_desc), label };
    }
};
// helper function for replacing {{PLACEHOLDERS}}
inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
const std::string & key,

View File

@ -8,7 +8,6 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "ggml-wgsl-shaders.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
@ -23,6 +22,7 @@
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
@ -268,12 +268,6 @@ struct webgpu_gpu_profile_buf_pool {
};
#endif
// A created compute pipeline bundled with a name identifying it.
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
// Human-readable name of the pipeline.
// NOTE(review): presumably the label/variant string it was created with — confirm.
std::string name;
// Type-erased, shared extra data attached to the pipeline; empty unless set.
// NOTE(review): contents depend on the pipeline's creator — confirm with the readers of this field.
std::shared_ptr<void> context = nullptr;
};
struct webgpu_command {
wgpu::CommandBuffer commands;
std::vector<webgpu_pool_bufs> params_bufs;
@ -358,6 +352,8 @@ struct webgpu_context_struct {
// Points to global instances owned by ggml_backend_webgpu_reg_context
webgpu_global_context global_ctx;
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
pre_wgsl::Preprocessor p;
webgpu_buf_pool param_buf_pool;
@ -377,7 +373,6 @@ struct webgpu_context_struct {
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
@ -2215,22 +2210,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
.vec4 = false,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->sum_rows_pipelines.find(1);
if (it != ctx->sum_rows_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->sum_rows_pipelines.emplace(1, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -2965,6 +2950,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);