From c10219705defef07c8c49622b17ea279f8cbbbb6 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 8 Sep 2025 10:15:21 -0700 Subject: [PATCH] Get addition and multiplication working --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 116 ++++++++++++++---- .../ggml-webgpu/wgsl-shaders/add.tmpl.wgsl | 62 +--------- .../wgsl-shaders/add_in_place.tmpl.wgsl | 41 +++++++ .../ggml-webgpu/wgsl-shaders/binary_head.tmpl | 45 +++++++ .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 26 +++- .../ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl | 44 +++++++ .../wgsl-shaders/mul_in_place.tmpl.wgsl | 41 +++++++ 7 files changed, 292 insertions(+), 83 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4bc011729e..2669a3f95c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -126,6 +126,9 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline cpy_pipeline; wgpu::ComputePipeline add_pipeline[2]; + wgpu::ComputePipeline add_ip_pipeline[2]; + wgpu::ComputePipeline mul_pipeline[2]; + wgpu::ComputePipeline mul_ip_pipeline[2]; size_t memset_bytes_per_thread; @@ -347,7 +350,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - bool submit_and_wait = false) { + const char * bind_group_label = nullptr, + bool submit_and_wait = false) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -368,6 +372,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & bind_group_desc.layout = pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = bind_group_entries.size(); bind_group_desc.entries = bind_group_entries.data(); + if (bind_group_label) { + bind_group_desc.label = bind_group_label; + } wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); @@ -413,7 +420,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, }; size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true); } /** End WebGPU Actions */ @@ -457,6 +464,12 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); } +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -485,7 +498,7 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -537,7 +550,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t std::lock_guard lock(ctx->mutex); ctx->staged_set_row_error_bufs.push_back(error_bufs); - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -577,10 +590,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x, + ggml_op_name(dst->op)); } -static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + wgpu::ComputePipeline & pipeline, + bool in_place) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -607,16 +626,18 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso { .binding = 1, .buffer = ggml_webgpu_tensor_buf(src1), .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } }; + if (!in_place) { + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x); + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } // Returns true if node has enqueued work into the queue, false otherwise @@ -654,7 +675,20 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { } case GGML_OP_ADD: { - ggml_webgpu_add(ctx, src0, src1, node); + if (ggml_webgpu_tensor_equal(src0, node)) { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); + } else { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); + } + break; + } + case GGML_OP_MUL: + { + if (ggml_webgpu_tensor_equal(src0, node)) { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); + } else { + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + } break; } default: @@ -994,8 +1028,28 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants(1); constants[0].key = "wg_size"; constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, + "add_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, + "add_in_place_f16", constants); +} + +static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, + "mul_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, + "mul_in_place_f16", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -1048,22 +1102,30 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { GGML_UNUSED(dev); + bool supports_op = false; switch (op->op) { case GGML_OP_NONE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_RESHAPE: + supports_op = true; + break; case GGML_OP_ADD: - return true; + case GGML_OP_MUL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && + (op->src[1]->type == op->type); + break; case GGML_OP_CPY: case GGML_OP_SET_ROWS: - return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; + supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32); + break; case GGML_OP_MUL_MAT: { switch (op->src[1]->type) { case GGML_TYPE_F16: - return op->src[0]->type == GGML_TYPE_F16; + supports_op = (op->src[0]->type == GGML_TYPE_F16); + break; case GGML_TYPE_F32: switch (op->src[0]->type) { case GGML_TYPE_F32: @@ -1087,17 +1149,26 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: - return true; + supports_op = true; + break; default: - return false; + break; } default: - return false; + break; } } default: - return false; + break; } +#ifdef GGML_WEBGPU_DEBUG + if (!supports_op) { + WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + } +#endif + return supports_op; } static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { @@ -1210,6 +1281,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_set_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_mul_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl index b888c3d10b..f261cbb553 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl @@ -19,6 +19,8 @@ enable f16; +#include "binary_head.tmpl" + @group(0) @binding(0) var src0: array<{{TYPE}}>; @@ -28,71 +30,15 @@ var src1: array<{{TYPE}}>; @group(0) @binding(2) var dst: array<{{TYPE}}>; -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - @group(0) @binding(3) var params: Params; override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; + if (gid.x < params.ne) { + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; } - - // i = thread id, ranges from 0 --> total ne - 1 - // represents the position in the flat array a we are adding with array b - var i = gid.x; - - // given the index of linear a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3] - // we need this because tensor a and b are different shapes - // so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a - - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - let src1_idx = b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; - - // actual addition operation, now that the indexes are all figured out - // ensuring that the offsets are included - // gid.x used for flat indexing into dst and a, since variable i was modified during calcs - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx]; } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl new file mode 100644 index 0000000000..903f7bdbcc --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl @@ -0,0 +1,41 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl new file mode 100644 index 0000000000..4b254f468d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl @@ -0,0 +1,45 @@ +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 1e518ec118..a9e73ed295 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -26,6 +26,24 @@ def replace_placeholders(shader_text, replacements): shader_text = re.sub(pattern, str(val), shader_text) return shader_text +def expand_includes(shader, input_dir): + """ + Replace #include "file" lines in the text with the contents of that file. + Searches for files relative to input_dir. + """ + include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE) + + def replacer(match): + fname = match.group(1) + file_path = os.path.join(input_dir, fname) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Included file not found: {file_path}") + with open(file_path, "r", encoding="utf-8") as f: + included_code = f.read() + # Recursively expand includes inside the included file + return expand_includes(included_code, input_dir) + + return include_pattern.sub(replacer, shader) def write_shader(shader_name, shader_code, output_dir, outfile): if output_dir: @@ -35,8 +53,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile): outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') -def generate_variants(shader_path, output_dir, outfile): - shader_base_name = shader_path.split("/")[-1].split(".")[0] +def generate_variants(fname, input_dir, output_dir, outfile): + shader_path = os.path.join(input_dir, fname) + shader_base_name = fname.split(".")[0] with open(shader_path, "r", encoding="utf-8") as f: text = f.read() @@ -52,6 +71,7 @@ def generate_variants(shader_path, output_dir, outfile): decls_map = {} shader_template = extract_block(text, "SHADER") + shader_template = expand_includes(shader_template, input_dir) for variant in variants: if "DECLS" in variant: decls = variant["DECLS"] @@ -89,7 +109,7 @@ def main(): out.write("// Auto-generated shader embedding\n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out) + generate_variants(fname, args.input_dir, args.output_dir, out) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl new file mode 100644 index 0000000000..12506e1420 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl @@ -0,0 +1,44 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl new file mode 100644 index 0000000000..e467e59edb --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl @@ -0,0 +1,41 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; + } +} + +#end(SHADER)