This commit is contained in:
Chen Yuan 2026-04-16 15:50:41 +00:00 committed by GitHub
commit a984d87457
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 338 additions and 11 deletions

View File

@ -248,6 +248,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key {
}
};
/** CONV 2D */
// Cache key for a specialized CONV_2D pipeline: one shader variant per
// (weight, input, output) tensor-type combination.
struct ggml_webgpu_conv2d_pipeline_key {
    ggml_type weight_type;
    ggml_type input_type;
    ggml_type output_type;

    // Two keys are equal iff all three tensor types match.
    bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
        if (weight_type != other.weight_type) {
            return false;
        }
        if (input_type != other.input_type) {
            return false;
        }
        return output_type == other.output_type;
    }
};
// Hash functor for ggml_webgpu_conv2d_pipeline_key, for use with
// std::unordered_map. Folds the three type fields into one value via the
// shared ggml_webgpu_hash_combine helper (combination order is fixed).
struct ggml_webgpu_conv2d_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
        size_t h = 0;
        ggml_webgpu_hash_combine(h, key.weight_type);
        ggml_webgpu_hash_combine(h, key.input_type);
        ggml_webgpu_hash_combine(h, key.output_type);
        return h;
    }
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
@ -831,6 +852,8 @@ class ggml_webgpu_shader_lib {
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -1115,8 +1138,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type)
{
switch (key.src_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@ -1136,9 +1158,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
}
defines.push_back("BYTE_HELPERS");
@ -1621,8 +1643,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (context.src0->type)
{
switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@ -1642,9 +1663,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");
@ -2340,6 +2361,47 @@ class ggml_webgpu_shader_lib {
return soft_max_pipelines[key];
}
/**
 * Returns the compute pipeline for the CONV_2D shader, creating and caching
 * it on first use for the (weight, input, output) type triple in `context`.
 *
 * Each distinct type combination gets its own preprocessed shader variant;
 * subsequent calls with the same triple hit the cache.
 */
webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_conv2d_pipeline_key key = {
        .weight_type = context.src0->type,
        .input_type  = context.src1->type,
        .output_type = context.dst->type,
    };

    auto it = conv2d_pipelines.find(key);
    if (it != conv2d_pipelines.end()) {
        return it->second;
    }

    std::vector<std::string> defines;
    std::string              variant = "conv_2d";

    // Map each tensor role to a *_F32/*_F16 define consumed by the shader
    // preprocessor; only F32/F16 are supported for CONV_2D.
    auto push_type_defines = [&](const char * prefix, ggml_type type) {
        std::string s_prefix = prefix;
        if (type == GGML_TYPE_F32) {
            defines.push_back(s_prefix + "_F32");
        } else if (type == GGML_TYPE_F16) {
            defines.push_back(s_prefix + "_F16");
        } else {
            GGML_ABORT("Unsupported type for CONV_2D shader");
        }
    };

    push_type_defines("WEIGHT", key.weight_type);
    push_type_defines("INPUT", key.input_type);
    push_type_defines("OUTPUT", key.output_type);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto processed = preprocessor.preprocess(wgsl_conv2d, defines);

    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context         = decisions;

    // Insert and return in one step; the original `map[key] = v; return
    // map[key];` pattern re-hashed the key twice after the insert.
    return conv2d_pipelines.emplace(key, pipeline).first->second;
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@ -8,6 +8,7 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "ggml.h"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@ -83,7 +84,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256 // enough for 64 parameters
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
@ -907,6 +908,97 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
// Encodes a GGML_OP_CONV_2D node for the WebGPU backend.
//
// src0 = weights [KW, KH, IC, OC], src1 = input [IW, IH, .., N],
// dst  = output  [OW, OH, OC, N]   (ne[0] is the innermost dimension,
// matching the layout comments in the conv2d shader).
//
// NOTE: the order of entries in `params` below is a positional contract with
// the shader's `Params` uniform struct — do not reorder either side alone.
static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
// Operator parameters stored on the dst node:
// strides (s0, s1), padding (p0, p1), dilation (d0, d1).
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
std::vector<uint32_t> params = {
// Element offsets of each tensor within its aligned buffer binding.
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Strides converted from bytes to elements (shader sw0..sw3).
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
// Input strides in elements (shader si0..si3).
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
// Output strides in elements (shader so0..so3).
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Kernel dimensions: KW, KH, IC.
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
// Input spatial dimensions: IW, IH.
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
// Output dimensions: OW, OH, OC, N.
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
// Stride, padding, dilation (passed through to the shader).
(uint32_t) s0,
(uint32_t) s1,
(uint32_t) p0,
(uint32_t) p1,
(uint32_t) d0,
(uint32_t) d1,
};
// Bindings 0-2: weights, input, output — must match the shader's @binding
// indices.
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
// Workgroup size is clamped to both relevant device limits; it also becomes
// the WG_SIZE define when the pipeline is specialized below.
uint32_t max_wg_size =
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
uint32_t wg_size =
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = wg_size,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
// One shader invocation per output element; if the required workgroup count
// exceeds the per-dimension limit, fold the dispatch into two dimensions
// (the shader flattens gid.x/gid.y back into a linear index).
uint32_t n_out = ggml_nelements(dst);
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t wg_x = std::min(total_wg, max_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
@ -2753,6 +2845,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context c
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
return ggml_webgpu_sum_rows(ctx, src0, node);
case GGML_OP_CONV_2D:
return ggml_webgpu_conv_2d(ctx, src0, src1, node);
default:
return std::nullopt;
}
@ -3781,6 +3875,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SOLVE_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
break;
case GGML_OP_CONV_2D:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;

View File

@ -0,0 +1,166 @@
#include "common_decls.tmpl"
enable f16;
// Storage bindings 0-2: convolution weights, input image, output image.
// The element type of each is selected at preprocess time by the
// WEIGHT_/INPUT_/OUTPUT_ F32/F16 defines.
@group(0) @binding(0)
#if defined(WEIGHT_F32)
var<storage, read_write> weights: array<f32>;
#elif defined(WEIGHT_F16)
var<storage, read_write> weights: array<f16>;
#endif
@group(0) @binding(1)
#if defined(INPUT_F32)
var<storage, read_write> input: array<f32>;
#elif defined(INPUT_F16)
var<storage, read_write> input: array<f16>;
#endif
@group(0) @binding(2)
#if defined(OUTPUT_F32)
var<storage, read_write> output: array<f32>;
#elif defined(OUTPUT_F16)
var<storage, read_write> output: array<f16>;
#endif
// Uniform parameters; field order is a positional contract with the
// host-side params vector (ggml_webgpu_conv_2d) — keep them in sync.
struct Params {
// element offsets of each tensor within its buffer binding
offset_w: u32,
offset_i: u32,
offset_o: u32,
// element strides
sw0: u32, sw1: u32, sw2: u32, sw3: u32,
si0: u32, si1: u32, si2: u32, si3: u32,
so0: u32, so1: u32, so2: u32, so3: u32,
// kernel dimensions
KW: u32, KH: u32, IC: u32,
// input dimensions
IW: u32, IH: u32,
// output dimensions
OW: u32, OH: u32, OC_out: u32, N_out: u32,
// stride
s0: u32, s1: u32,
// padding
p0: u32, p1: u32,
// dilation
d0: u32, d1: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
// Reads one weight element, widening f16 to f32 when needed so that the
// convolution accumulates in f32 regardless of storage type.
fn load_weight(idx: u32) -> f32 {
#if defined(WEIGHT_F32)
return weights[idx];
#elif defined(WEIGHT_F16)
return f32(weights[idx]);
#endif
}
// Reads one input element, widening f16 to f32 when needed so that the
// convolution accumulates in f32 regardless of storage type.
fn load_input(idx: u32) -> f32 {
#if defined(INPUT_F32)
return input[idx];
#elif defined(INPUT_F16)
return f32(input[idx]);
#endif
}
// Writes one output element, narrowing the f32 accumulator to f16 when the
// output tensor is stored as f16.
fn store_output(idx: u32, val: f32) {
#if defined(OUTPUT_F32)
output[idx] = val;
#elif defined(OUTPUT_F16)
output[idx] = f16(val);
#endif
}
// Unsigned ceiling division: smallest q such that q * d >= n (d > 0).
fn ceil_div_u32(n: u32, d: u32) -> u32 {
    return (n + d - 1) / d;
}
// Smallest kernel tap k for which base + k * step lands on a non-negative
// coordinate; k = 0 when base itself is already in range.
fn first_valid_k(base: i32, step: u32) -> u32 {
    if (base < 0) {
        return ceil_div_u32(u32(-base), step);
    }
    return 0;
}
// First kernel tap k for which base + k * step >= limit, clamped to k_max;
// together with first_valid_k, the in-bounds taps form [begin, end).
fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 {
    let span = i32(limit) - base;
    if (span > 0) {
        return min(k_max, ceil_div_u32(u32(span), step));
    }
    return 0;
}
// Direct conv2d: one invocation per output element. WG_SIZE is injected by
// the shader preprocessor to match the pipeline's workgroup size.
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let threads_per_group = u32(WG_SIZE);
// The host may fold a large dispatch into two dimensions; flatten it back
// into a single linear output index here.
let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
let n_out = params.OW * params.OH * params.OC_out * params.N_out;
var sum: f32 = 0.0;
if (i_out >= n_out) {
return;
}
// Kernel layout: [KW, KH, IC, ..]
// Input layout: [IW, IH, .., ..]
// Output layout: [OW, OH, OC, N]
// Decompose the linear index into (n, oc, oh, ow), innermost last.
var i = i_out;
let n = i / (params.OC_out * params.OH * params.OW);
i = i % (params.OC_out * params.OH * params.OW);
let oc = i / (params.OH * params.OW);
i = i % (params.OH * params.OW);
let oh = i / params.OW;
let ow = i % params.OW;
// Top-left input coordinate of this output's receptive field; can be
// negative because of padding, hence the signed arithmetic.
let ow_base = i32(ow * params.s0) - i32(params.p0);
let oh_base = i32(oh * params.s1) - i32(params.p1);
// clip the valid kernel window once
let kw_begin = first_valid_k(ow_base, params.d0);
let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW);
let kh_begin = first_valid_k(oh_base, params.d1);
let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH);
// entire receptive field is out of bounds
if (kw_begin >= kw_end || kh_begin >= kh_end) {
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
store_output(out_idx, 0.0);
return;
}
// Hoist the per-(oc, n) base offsets out of the channel/kernel loops.
let weight_oc_base = params.offset_w + oc * params.sw3;
let input_n_base = params.offset_i + n * params.si3;
// Accumulate over input channels and the pre-clipped kernel window; the
// clipping above means no per-tap bounds checks are needed here.
for (var ic: u32 = 0; ic < params.IC; ic += 1) {
let w_base_ic = ic * params.sw2 + weight_oc_base;
let in_base = ic * params.si2 + input_n_base;
for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) {
let ih = u32(oh_base + i32(kh * params.d1));
let w_row_base = w_base_ic + kh * params.sw1;
let in_row_base = in_base + ih * params.si1;
for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) {
let iw = u32(ow_base + i32(kw * params.d0));
let w_idx = w_row_base + kw * params.sw0;
let in_idx = in_row_base + iw * params.si0;
sum += load_weight(w_idx) * load_input(in_idx);
}
}
}
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
store_output(out_idx, sum);
}