Get addition and multiplication working
parent 7f9ee10e75
commit c10219705d
@@ -126,6 +126,9 @@ struct webgpu_context_struct {
    wgpu::ComputePipeline set_rows_pipeline;
    wgpu::ComputePipeline cpy_pipeline;
    wgpu::ComputePipeline add_pipeline[2];
    wgpu::ComputePipeline add_ip_pipeline[2];
    wgpu::ComputePipeline mul_pipeline[2];
    wgpu::ComputePipeline mul_ip_pipeline[2];

    size_t memset_bytes_per_thread;
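The two-element pipeline arrays are indexed directly by the tensor's ggml type: GGML_TYPE_F32 and GGML_TYPE_F16 are the first two values of the enum (0 and 1), and the supports_op check later in this diff restricts ADD/MUL to exactly those types. A minimal sketch of that selection pattern, with a hypothetical helper name that is not part of the commit:

```cpp
// Hypothetical helper illustrating why a [2] array indexed by ggml type works:
// only F32 (enum value 0) and F16 (enum value 1) are allowed through for ADD/MUL.
#include <cassert>

static wgpu::ComputePipeline & pick_add_pipeline(webgpu_context & ctx, const ggml_tensor * dst, bool in_place) {
    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
    return in_place ? ctx->add_ip_pipeline[dst->type] : ctx->add_pipeline[dst->type];
}
```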
@@ -347,7 +350,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                  std::vector<uint32_t> params,
                                                  std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                  uint32_t wg_x,
                                                  bool submit_and_wait = false) {
                                                  const char * bind_group_label = nullptr,
                                                  bool submit_and_wait = false) {
    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

    ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@@ -368,6 +372,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
    bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
    bind_group_desc.entryCount = bind_group_entries.size();
    bind_group_desc.entries = bind_group_entries.data();
    if (bind_group_label) {
        bind_group_desc.label = bind_group_label;
    }
    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);

    wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
@@ -413,7 +420,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
    };
    size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
    uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
}

/** End WebGPU Actions */
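A note on the dispatch size above: `wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg` is a ceiling division, with `size + 3` presumably rounding the byte count up to a whole 32-bit word before splitting the work across workgroups (an assumption; the memset shader itself is not part of this diff). A tiny standalone check of the arithmetic, with illustrative values:

```cpp
// Standalone sketch of the ceiling division used for the memset dispatch.
// All values below are examples, not taken from the commit.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    const size_t   wg_size          = 256;   // example maxComputeWorkgroupSizeX
    const size_t   bytes_per_thread = 4;     // example memset_bytes_per_thread
    const size_t   bytes_per_wg     = wg_size * bytes_per_thread;  // 1024
    const size_t   size             = 4099;  // bytes to clear, not a multiple of 4
    // Round up to a whole 4-byte word, then take the ceiling over bytes_per_wg.
    const uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
    printf("wg_x = %u\n", wg_x);  // prints 5: 4102 bytes need five 1024-byte workgroups
    return 0;
}
```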
@@ -457,6 +464,12 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
           ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
}

// Used to determine if two tensors are the same for in-place operations
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}

static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -485,7 +498,7 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor

    size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
    uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}

static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@@ -537,7 +550,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
    ctx->staged_set_row_error_bufs.push_back(error_bufs);

    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}

static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -577,10 +590,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t

    uint32_t wg_x =
        (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
                                          ggml_op_name(dst->op));
}

static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_webgpu_binary_op(webgpu_context & ctx,
                                  ggml_tensor * src0,
                                  ggml_tensor * src1,
                                  ggml_tensor * dst,
                                  wgpu::ComputePipeline & pipeline,
                                  bool in_place) {
    std::vector<uint32_t> params = {
        (uint32_t) ggml_nelements(dst),
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -607,16 +626,18 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
        { .binding = 1,
          .buffer = ggml_webgpu_tensor_buf(src1),
          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
        { .binding = 2,
          .buffer = ggml_webgpu_tensor_buf(dst),
          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
          .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
    };
    if (!in_place) {
        entries.push_back({ .binding = 2,
                            .buffer = ggml_webgpu_tensor_buf(dst),
                            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                            .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
    }

    size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
    uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x);
    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}

// Returns true if node has enqueued work into the queue, false otherwise
@@ -654,7 +675,20 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
            }
        case GGML_OP_ADD:
            {
                ggml_webgpu_add(ctx, src0, src1, node);
                if (ggml_webgpu_tensor_equal(src0, node)) {
                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
                } else {
                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
                }
                break;
            }
        case GGML_OP_MUL:
            {
                if (ggml_webgpu_tensor_equal(src0, node)) {
                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
                } else {
                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
                }
                break;
            }
        default:
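The in-place pipelines are taken whenever the destination aliases src0, which is the case when a graph is built with ggml's in-place variants (or otherwise reuses src0's buffer at the same offset). A minimal sketch of a graph that would take this path; the context setup is illustrative and simplified (real backend usage allocates tensors in backend buffers), not part of the commit:

```cpp
// Illustrative only: builds an ADD node whose destination aliases src0,
// which is the situation ggml_webgpu_tensor_equal() is meant to detect.
#include "ggml.h"

void build_inplace_add_example(void) {
    struct ggml_init_params ip = {
        /* .mem_size   = */ 16 * 1024 * 1024,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

    // out is a view of a, so the backend should see dst and src0 with the same
    // buffer and offset and dispatch add_ip_pipeline instead of add_pipeline.
    struct ggml_tensor * out = ggml_add_inplace(ctx, a, b);
    (void) out;

    ggml_free(ctx);
}
```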
@@ -994,8 +1028,28 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants(1);
    constants[0].key = "wg_size";
    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
                                constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
                                constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
                                "add_in_place_f32", constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
                                "add_in_place_f16", constants);
}

static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants(1);
    constants[0].key = "wg_size";
    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
                                constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
                                constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
                                "mul_in_place_f32", constants);
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
                                "mul_in_place_f16", constants);
}

static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -1048,22 +1102,30 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    GGML_UNUSED(dev);

    bool supports_op = false;
    switch (op->op) {
        case GGML_OP_NONE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_RESHAPE:
            supports_op = true;
            break;
        case GGML_OP_ADD:
            return true;
        case GGML_OP_MUL:
            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
                          (op->src[1]->type == op->type);
            break;
        case GGML_OP_CPY:
        case GGML_OP_SET_ROWS:
            return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
            break;
        case GGML_OP_MUL_MAT:
            {
                switch (op->src[1]->type) {
                    case GGML_TYPE_F16:
                        return op->src[0]->type == GGML_TYPE_F16;
                        supports_op = (op->src[0]->type == GGML_TYPE_F16);
                        break;
                    case GGML_TYPE_F32:
                        switch (op->src[0]->type) {
                            case GGML_TYPE_F32:
@@ -1087,17 +1149,26 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                            case GGML_TYPE_IQ1_M:
                            case GGML_TYPE_IQ4_NL:
                            case GGML_TYPE_IQ4_XS:
                                return true;
                                supports_op = true;
                                break;
                            default:
                                return false;
                                break;
                        }
                    default:
                        return false;
                        break;
                }
            }
        default:
            return false;
            break;
    }
#ifdef GGML_WEBGPU_DEBUG
    if (!supports_op) {
        WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
                         << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
                         << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
    }
#endif
    return supports_op;
}

static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
@@ -1210,6 +1281,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
    ggml_webgpu_init_set_rows_pipeline(ctx);
    ggml_webgpu_init_cpy_pipeline(ctx);
    ggml_webgpu_init_add_pipeline(ctx);
    ggml_webgpu_init_mul_pipeline(ctx);

#ifdef GGML_WEBGPU_DEBUG
    // Initialize debug buffers
@@ -19,6 +19,8 @@

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@@ -28,71 +30,15 @@ var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

struct Params {
    ne: u32,

    // offsets in elements
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};

@group(0) @binding(3)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    if (gid.x < params.ne) {
        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
    }

    // i = thread id, ranges from 0 --> total ne - 1
    // represents the position in the flat array a we are adding with array b
    var i = gid.x;

    // given the index of linear a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3]
    // we need this because tensor a and b are different shapes
    // so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a

    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
    let a_i2 = i / (params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne1 * params.a_ne0);
    let a_i1 = i / params.a_ne0;
    let a_i0 = i % params.a_ne0;

    // handle repetition of b
    // index loops back to the beginning and repeats after elements are exhausted = modulo
    let b_i0 = a_i0 % params.b_ne0;
    let b_i1 = a_i1 % params.b_ne1;
    let b_i2 = a_i2 % params.b_ne2;
    let b_i3 = a_i3 % params.b_ne3;

    // compute index for position in b's flat array
    let src1_idx = b_i0 * params.stride_src1_0 +
                   b_i1 * params.stride_src1_1 +
                   b_i2 * params.stride_src1_2 +
                   b_i3 * params.stride_src1_3;

    // actual addition operation, now that the indexes are all figured out
    // ensuring that the offsets are included
    // gid.x used for flat indexing into dst and a, since variable i was modified during calcs
    dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
}

#end(SHADER)
@@ -0,0 +1,41 @@
#define(VARIANTS)

[
    {
        "REPLS": {
            "TYPE" : "f32",
        }
    },
    {
        "REPLS": {
            "TYPE" : "f16",
        }
    }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
@@ -0,0 +1,45 @@
struct Params {
    ne: u32,

    // offsets in elements
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    stride_src1_0: u32,
    stride_src1_1: u32,
    stride_src1_2: u32,
    stride_src1_3: u32,

    a_ne0: u32,
    a_ne1: u32,
    a_ne2: u32,

    b_ne0: u32,
    b_ne1: u32,
    b_ne2: u32,
    b_ne3: u32,
};

fn src1_index(_i: u32) -> u32 {
    var i = _i;
    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
    let a_i2 = i / (params.a_ne1 * params.a_ne0);
    i = i % (params.a_ne1 * params.a_ne0);
    let a_i1 = i / params.a_ne0;
    let a_i0 = i % params.a_ne0;

    // handle repetition of b
    // index loops back to the beginning and repeats after elements are exhausted = modulo
    let b_i0 = a_i0 % params.b_ne0;
    let b_i1 = a_i1 % params.b_ne1;
    let b_i2 = a_i2 % params.b_ne2;
    let b_i3 = a_i3 % params.b_ne3;

    // compute index for position in b's flat array
    return b_i0 * params.stride_src1_0 +
           b_i1 * params.stride_src1_1 +
           b_i2 * params.stride_src1_2 +
           b_i3 * params.stride_src1_3;
}
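To make the broadcast logic above concrete, here is the same index computation as a small standalone C++ sketch with a worked example; the struct name, shapes, and strides are illustrative and not taken from the commit:

```cpp
// Standalone C++ version of src1_index(): map a flat index into tensor a
// onto the matching element of a broadcast tensor b. Values are examples only.
#include <cstdint>
#include <cstdio>

struct BinaryParams {
    uint32_t a_ne0, a_ne1, a_ne2;          // a's dims (a_ne3 is implied by the total element count)
    uint32_t b_ne0, b_ne1, b_ne2, b_ne3;   // b's dims
    uint32_t s1_0, s1_1, s1_2, s1_3;       // b's strides, in elements
};

uint32_t src1_index(uint32_t i, const BinaryParams & p) {
    // Unflatten i into a's 4D coordinates.
    const uint32_t a_i3 = i / (p.a_ne2 * p.a_ne1 * p.a_ne0);
    i %= p.a_ne2 * p.a_ne1 * p.a_ne0;
    const uint32_t a_i2 = i / (p.a_ne1 * p.a_ne0);
    i %= p.a_ne1 * p.a_ne0;
    const uint32_t a_i1 = i / p.a_ne0;
    const uint32_t a_i0 = i % p.a_ne0;
    // Wrap each coordinate by b's extent (broadcast = repeat via modulo),
    // then rebuild a flat index using b's strides.
    return (a_i0 % p.b_ne0) * p.s1_0 + (a_i1 % p.b_ne1) * p.s1_1 +
           (a_i2 % p.b_ne2) * p.s1_2 + (a_i3 % p.b_ne3) * p.s1_3;
}

int main() {
    // a is 4x3, b is 4x1 broadcast across a's second dimension (contiguous strides).
    const BinaryParams p = { 4, 3, 1,  4, 1, 1, 1,  1, 4, 4, 4 };
    // Element (i0=2, i1=1) of a has flat index 1*4 + 2 = 6; it pairs with b[2].
    printf("src1_index(6) = %u\n", src1_index(6, p));  // prints 2
    return 0;
}
```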
@@ -26,6 +26,24 @@ def replace_placeholders(shader_text, replacements):
        shader_text = re.sub(pattern, str(val), shader_text)
    return shader_text


def expand_includes(shader, input_dir):
    """
    Replace #include "file" lines in the text with the contents of that file.
    Searches for files relative to input_dir.
    """
    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)

    def replacer(match):
        fname = match.group(1)
        file_path = os.path.join(input_dir, fname)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Included file not found: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            included_code = f.read()
        # Recursively expand includes inside the included file
        return expand_includes(included_code, input_dir)

    return include_pattern.sub(replacer, shader)


def write_shader(shader_name, shader_code, output_dir, outfile):
    if output_dir:
@@ -35,8 +53,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile):
    outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')


def generate_variants(shader_path, output_dir, outfile):
    shader_base_name = shader_path.split("/")[-1].split(".")[0]
def generate_variants(fname, input_dir, output_dir, outfile):
    shader_path = os.path.join(input_dir, fname)
    shader_base_name = fname.split(".")[0]

    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()
@@ -52,6 +71,7 @@ def generate_variants(shader_path, output_dir, outfile):
    decls_map = {}

    shader_template = extract_block(text, "SHADER")
    shader_template = expand_includes(shader_template, input_dir)
    for variant in variants:
        if "DECLS" in variant:
            decls = variant["DECLS"]
@@ -89,7 +109,7 @@ def main():
        out.write("// Auto-generated shader embedding\n\n")
        for fname in sorted(os.listdir(args.input_dir)):
            if fname.endswith(".wgsl"):
                generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
                generate_variants(fname, args.input_dir, args.output_dir, out)


if __name__ == "__main__":
@@ -0,0 +1,44 @@
#define(VARIANTS)

[
    {
        "REPLS": {
            "TYPE" : "f32",
        }
    },
    {
        "REPLS": {
            "TYPE" : "f16",
        }
    }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
@@ -0,0 +1,41 @@
#define(VARIANTS)

[
    {
        "REPLS": {
            "TYPE" : "f32",
        }
    },
    {
        "REPLS": {
            "TYPE" : "f16",
        }
    }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)