Get addition and multiplication working

Reese Levine 2025-09-08 10:15:21 -07:00
parent 7f9ee10e75
commit c10219705d
7 changed files with 292 additions and 83 deletions

View File

@@ -126,6 +126,9 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;
     wgpu::ComputePipeline add_pipeline[2];
+    wgpu::ComputePipeline add_ip_pipeline[2];
+    wgpu::ComputePipeline mul_pipeline[2];
+    wgpu::ComputePipeline mul_ip_pipeline[2];
     size_t memset_bytes_per_thread;
@@ -347,6 +350,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
+                                                  const char * bind_group_label = nullptr,
                                                   bool submit_and_wait = false) {
     webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
@@ -368,6 +372,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
     bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
     bind_group_desc.entryCount = bind_group_entries.size();
     bind_group_desc.entries = bind_group_entries.data();
+    if (bind_group_label) {
+        bind_group_desc.label = bind_group_label;
+    }
     wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
@@ -413,7 +420,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
     };
     size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
     uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
 }
 /** End WebGPU Actions */
@@ -457,6 +464,12 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
            ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
 }
+// Used to determine if two tensors are the same for in-place operations
+static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
+}
 static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -485,7 +498,7 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
     size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
     uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@@ -537,7 +550,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
     std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
     ctx->staged_set_row_error_bufs.push_back(error_bufs);
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -577,10 +590,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
     uint32_t wg_x =
         (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
-static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_webgpu_binary_op(webgpu_context & ctx,
+                                  ggml_tensor * src0,
+                                  ggml_tensor * src1,
+                                  ggml_tensor * dst,
+                                  wgpu::ComputePipeline & pipeline,
+                                  bool in_place) {
     std::vector<uint32_t> params = {
         (uint32_t) ggml_nelements(dst),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -607,16 +626,18 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tenso
         { .binding = 1,
           .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
-        { .binding = 2,
+          .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+    };
+    if (!in_place) {
+        entries.push_back({ .binding = 2,
           .buffer = ggml_webgpu_tensor_buf(dst),
           .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
-    };
+          .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
     size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
     uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 // Returns true if node has enqueued work into the queue, false otherwise
@@ -654,7 +675,20 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
             }
         case GGML_OP_ADD:
             {
-                ggml_webgpu_add(ctx, src0, src1, node);
+                if (ggml_webgpu_tensor_equal(src0, node)) {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
+                } else {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
+                }
+                break;
+            }
+        case GGML_OP_MUL:
+            {
+                if (ggml_webgpu_tensor_equal(src0, node)) {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
+                } else {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
+                }
                 break;
             }
         default:
@@ -994,8 +1028,28 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
     constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
+                                "add_in_place_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
+                                "add_in_place_f16", constants);
+}
+
+static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
+                                "mul_in_place_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
+                                "mul_in_place_f16", constants);
 }
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1048,22 +1102,30 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
 static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     GGML_UNUSED(dev);
+    bool supports_op = false;
     switch (op->op) {
         case GGML_OP_NONE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_RESHAPE:
+            supports_op = true;
+            break;
         case GGML_OP_ADD:
-            return true;
+        case GGML_OP_MUL:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
+                          (op->src[1]->type == op->type);
+            break;
         case GGML_OP_CPY:
         case GGML_OP_SET_ROWS:
-            return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
+            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
+            break;
         case GGML_OP_MUL_MAT:
             {
                 switch (op->src[1]->type) {
                     case GGML_TYPE_F16:
-                        return op->src[0]->type == GGML_TYPE_F16;
+                        supports_op = (op->src[0]->type == GGML_TYPE_F16);
+                        break;
                     case GGML_TYPE_F32:
                         switch (op->src[0]->type) {
                             case GGML_TYPE_F32:
@@ -1087,17 +1149,26 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                             case GGML_TYPE_IQ1_M:
                             case GGML_TYPE_IQ4_NL:
                             case GGML_TYPE_IQ4_XS:
-                                return true;
+                                supports_op = true;
+                                break;
                             default:
-                                return false;
+                                break;
                         }
                     default:
-                        return false;
+                        break;
                 }
             }
         default:
-            return false;
+            break;
     }
+#ifdef GGML_WEBGPU_DEBUG
+    if (!supports_op) {
+        WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                                           << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                                           << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+    }
+#endif
+    return supports_op;
 }
 static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
@@ -1210,6 +1281,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_set_rows_pipeline(ctx);
     ggml_webgpu_init_cpy_pipeline(ctx);
     ggml_webgpu_init_add_pipeline(ctx);
+    ggml_webgpu_init_mul_pipeline(ctx);
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
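
The backend changes above route GGML_OP_ADD and GGML_OP_MUL through the shared ggml_webgpu_binary_op helper, selecting the *_ip (in-place) pipelines when ggml_webgpu_tensor_equal reports that src0 and dst alias the same buffer and offset. As a rough host-side reference (not part of the commit; shapes are illustrative), the shaders perform an element-wise add or multiply with src1 repeated to cover src0, which in the common case where every src1 extent is either 1 or equal to src0's matches NumPy broadcasting:

import numpy as np

# illustrative shapes: src0/dst is (2, 3, 4); src1 is (1, 1, 4), repeated over the leading dims
src0 = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
src1 = np.arange(4, dtype=np.float32).reshape(1, 1, 4)

dst_add = src0 + src1    # GGML_OP_ADD with a separate dst binding
dst_mul = src0 * src1    # GGML_OP_MUL with a separate dst binding

src0 += src1             # the add_in_place pipeline: the result is written back into src0
assert np.allclose(src0, dst_add)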

View File

@@ -19,6 +19,8 @@
 enable f16;
+#include "binary_head.tmpl"
 @group(0) @binding(0)
 var<storage, read_write> src0: array<{{TYPE}}>;
@@ -28,71 +30,15 @@ var<storage, read_write> src1: array<{{TYPE}}>;
 @group(0) @binding(2)
 var<storage, read_write> dst: array<{{TYPE}}>;
-struct Params {
-    ne: u32,
-    // offsets in elements
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    stride_src1_0: u32,
-    stride_src1_1: u32,
-    stride_src1_2: u32,
-    stride_src1_3: u32,
-    a_ne0: u32,
-    a_ne1: u32,
-    a_ne2: u32,
-    b_ne0: u32,
-    b_ne1: u32,
-    b_ne2: u32,
-    b_ne3: u32,
-};
 @group(0) @binding(3)
 var<uniform> params: Params;
 override wg_size: u32;
 @compute @workgroup_size(wg_size)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-        return;
+    if (gid.x < params.ne) {
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
     }
-    // i = thread id, ranges from 0 --> total ne - 1
-    // represents the position in the flat array a we are adding with array b
-    var i = gid.x;
-    // given the index of linear a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3]
-    // we need this because tensor a and b are different shapes
-    // so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a
-    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    let a_i2 = i / (params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne1 * params.a_ne0);
-    let a_i1 = i / params.a_ne0;
-    let a_i0 = i % params.a_ne0;
-    // handle repetition of b
-    // index loops back to the beginning and repeats after elements are exhausted = modulo
-    let b_i0 = a_i0 % params.b_ne0;
-    let b_i1 = a_i1 % params.b_ne1;
-    let b_i2 = a_i2 % params.b_ne2;
-    let b_i3 = a_i3 % params.b_ne3;
-    // compute index for position in b's flat array
-    let src1_idx = b_i0 * params.stride_src1_0 +
-                   b_i1 * params.stride_src1_1 +
-                   b_i2 * params.stride_src1_2 +
-                   b_i3 * params.stride_src1_3;
-    // actual addition operation, now that the indexes are all figured out
-    // ensuring that the offsets are included
-    // gid.x used for flat indexing into dst and a, since variable i was modified during calcs
-    dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
 }
 #end(SHADER)

View File

@@ -0,0 +1,41 @@
#define(VARIANTS)
[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
    }
}
#end(SHADER)

View File

@@ -0,0 +1,45 @@
struct Params {
    ne: u32,
    // offsets in elements
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,
    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,
    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,
    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
    var i = _i;
    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
    let a_i2 = i / (params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne1 * params.a_ne0);
    let a_i1 = i / params.a_ne0;
    let a_i0 = i % params.a_ne0;
    // handle repetition of b
    // index loops back to the beginning and repeats after elements are exhausted = modulo
    let b_i0 = a_i0 % params.b_ne0;
    let b_i1 = a_i1 % params.b_ne1;
    let b_i2 = a_i2 % params.b_ne2;
    let b_i3 = a_i3 % params.b_ne3;
    // compute index for position in b's flat array
    return b_i0 * params.stride_src1_0 +
           b_i1 * params.stride_src1_1 +
           b_i2 * params.stride_src1_2 +
           b_i3 * params.stride_src1_3;
}
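
The src1_index() helper above is what lets one shader handle broadcast/repeated operands: it unflattens the linear index into tensor a's 4-D coordinates, wraps each coordinate into b's extent with a modulo, and re-flattens using b's strides. A small host-side check of the same arithmetic (not part of the commit; shapes and strides are made up, with src1 stored contiguously for its own extents):

A_NE = (4, 3, 2, 1)      # a_ne0..a_ne3: extents of the larger operand
B_NE = (4, 1, 2, 1)      # b_ne0..b_ne3: extents of src1; size-1 dims repeat
B_STRIDE = (1, 4, 4, 8)  # stride_src1_0..3 in elements, contiguous for B_NE

def src1_index(i: int) -> int:
    # recover a's 4-D coordinates from the flat index
    a_i3 = i // (A_NE[2] * A_NE[1] * A_NE[0])
    i %= A_NE[2] * A_NE[1] * A_NE[0]
    a_i2 = i // (A_NE[1] * A_NE[0])
    i %= A_NE[1] * A_NE[0]
    a_i1 = i // A_NE[0]
    a_i0 = i % A_NE[0]
    # wrap each coordinate into b's extent (modulo = repetition), then flatten with b's strides
    return (a_i0 % B_NE[0]) * B_STRIDE[0] + (a_i1 % B_NE[1]) * B_STRIDE[1] + \
           (a_i2 % B_NE[2]) * B_STRIDE[2] + (a_i3 % B_NE[3]) * B_STRIDE[3]

# flat index 13 in a -> coordinates (a_i0, a_i1, a_i2, a_i3) = (1, 0, 1, 0) -> element 5 of src1
assert src1_index(13) == 5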

View File

@@ -26,6 +26,24 @@ def replace_placeholders(shader_text, replacements):
         shader_text = re.sub(pattern, str(val), shader_text)
     return shader_text
+def expand_includes(shader, input_dir):
+    """
+    Replace #include "file" lines in the text with the contents of that file.
+    Searches for files relative to input_dir.
+    """
+    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
+    def replacer(match):
+        fname = match.group(1)
+        file_path = os.path.join(input_dir, fname)
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Included file not found: {file_path}")
+        with open(file_path, "r", encoding="utf-8") as f:
+            included_code = f.read()
+        # Recursively expand includes inside the included file
+        return expand_includes(included_code, input_dir)
+    return include_pattern.sub(replacer, shader)
 def write_shader(shader_name, shader_code, output_dir, outfile):
     if output_dir:
@@ -35,8 +53,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile):
     outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
-def generate_variants(shader_path, output_dir, outfile):
-    shader_base_name = shader_path.split("/")[-1].split(".")[0]
+def generate_variants(fname, input_dir, output_dir, outfile):
+    shader_path = os.path.join(input_dir, fname)
+    shader_base_name = fname.split(".")[0]
     with open(shader_path, "r", encoding="utf-8") as f:
         text = f.read()
@@ -52,6 +71,7 @@ def generate_variants(shader_path, output_dir, outfile):
     decls_map = {}
     shader_template = extract_block(text, "SHADER")
+    shader_template = expand_includes(shader_template, input_dir)
     for variant in variants:
         if "DECLS" in variant:
             decls = variant["DECLS"]
@@ -89,7 +109,7 @@ def main():
         out.write("// Auto-generated shader embedding\n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
-                generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
+                generate_variants(fname, args.input_dir, args.output_dir, out)
 if __name__ == "__main__":
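
A quick sketch of what the new expand_includes() does (illustrative only, not part of the commit; it assumes the function above is in scope):

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # stand-in for the shared binary_head.tmpl template
    with open(os.path.join(tmp, "binary_head.tmpl"), "w", encoding="utf-8") as f:
        f.write("struct Params { ne: u32, };\n")

    shader = 'enable f16;\n#include "binary_head.tmpl"\nfn main() {}\n'
    print(expand_includes(shader, tmp))
    # the #include line is replaced by the template's contents; any includes nested
    # inside binary_head.tmpl would be expanded recursively the same way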

View File

@@ -0,0 +1,44 @@
#define(VARIANTS)
[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}
#end(SHADER)

View File

@@ -0,0 +1,41 @@
#define(VARIANTS)
[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}
#end(SHADER)