Start work on all-encompassing shader library
This commit is contained in:
parent
b3927f807d
commit
75e66cb49d
|
|
@ -3,6 +3,9 @@
|
|||
|
||||
#include "ggml.h"
|
||||
#include "pre_wgsl.hpp"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
|
@ -47,6 +50,64 @@
|
|||
#define WEBGPU_MAX_WG_SIZE 288
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
|
||||
struct ggml_webgpu_shader_lib_context {
|
||||
ggml_tensor * src0;
|
||||
ggml_tensor * src1;
|
||||
ggml_tensor * dst;
|
||||
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
struct webgpu_pipeline {
|
||||
wgpu::ComputePipeline pipeline;
|
||||
std::string name;
|
||||
std::shared_ptr<void> context = nullptr;
|
||||
};
|
||||
|
||||
class ggml_webgpu_shader_lib {
|
||||
wgpu::Device device;
|
||||
pre_wgsl::Preprocessor preprocessor;
|
||||
|
||||
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
|
||||
webgpu_pipeline get_sum_rows_pipeline(ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = sum_rows_pipelines.find(1);
|
||||
if (it != sum_rows_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::vector<std::string> defines;
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_sum_rows, defines);
|
||||
sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
|
||||
return sum_rows_pipelines[1];
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||
std::string shader_code,
|
||||
std::string label) {
|
||||
wgpu::ShaderSourceWGSL shader_source;
|
||||
shader_source.code = shader_code.c_str();
|
||||
|
||||
wgpu::ShaderModuleDescriptor shader_desc;
|
||||
shader_desc.nextInChain = &shader_source;
|
||||
|
||||
wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
|
||||
|
||||
wgpu::ComputePipelineDescriptor pipeline_desc;
|
||||
pipeline_desc.label = label.c_str();
|
||||
pipeline_desc.compute.module = shader_module;
|
||||
pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
|
||||
pipeline_desc.layout = nullptr; // nullptr means auto layout
|
||||
return { device.CreateComputePipeline(&pipeline_desc), label };
|
||||
}
|
||||
};
|
||||
|
||||
// helper function for replacing {{PLACEHOLDERS}}
|
||||
inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
|
||||
const std::string & key,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
|
|
@ -23,6 +22,7 @@
|
|||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
|
@ -268,12 +268,6 @@ struct webgpu_gpu_profile_buf_pool {
|
|||
};
|
||||
#endif
|
||||
|
||||
struct webgpu_pipeline {
|
||||
wgpu::ComputePipeline pipeline;
|
||||
std::string name;
|
||||
std::shared_ptr<void> context = nullptr;
|
||||
};
|
||||
|
||||
struct webgpu_command {
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
|
|
@ -358,6 +352,8 @@ struct webgpu_context_struct {
|
|||
// Points to global instances owned by ggml_backend_webgpu_reg_context
|
||||
webgpu_global_context global_ctx;
|
||||
|
||||
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
||||
|
||||
pre_wgsl::Preprocessor p;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
|
|
@ -377,7 +373,6 @@ struct webgpu_context_struct {
|
|||
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
|
||||
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
|
||||
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
||||
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
||||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
|
|
@ -2215,22 +2210,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
|
||||
.vec4 = false,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
auto it = ctx->sum_rows_pipelines.find(1);
|
||||
if (it != ctx->sum_rows_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows");
|
||||
pipeline =
|
||||
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
|
||||
ctx->sum_rows_pipelines.emplace(1, pipeline);
|
||||
}
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
|
||||
|
||||
uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
|
@ -2965,6 +2950,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||
ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context;
|
||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||
webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx;
|
||||
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
|
||||
|
|
|
|||
Loading…
Reference in New Issue