From e1f6baea31645e5d96ad53664acae856f74b96f4 Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 29 Oct 2025 23:08:37 -0700 Subject: [PATCH] implemented REPL_Template support and removed bug in unary operators kernel --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 16 + .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 688 +++++++++--------- 3 files changed, 371 insertions(+), 337 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 2f4fdc1c3c..0fc2691cc4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -865,7 +865,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * dst, webgpu_pipeline & pipeline, bool in_place, - const std::vector & xielu_params = {}) { + const std::vector & extra_params = {}) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -881,7 +881,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; - params.insert(params.end(), xielu_params.begin(), xielu_params.end()); + params.insert(params.end(), extra_params.begin(), extra_params.end()); std::vector entries = { { .binding = 0, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaec..7de19bef77 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -18,6 +18,14 @@ def parse_decls(decls_text): decls[name.strip()] = code.strip() return decls +def replace_repl_placeholders(variant, template_map): + + for repl, code in variant["REPLS"].items(): + for key, val in template_map.items(): + # Match "key" and avoid matching subsequences using by using \b + code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) + variant["REPLS"][repl] = code + return variant def replace_placeholders(shader_text, replacements): for key, val in replacements.items(): @@ -71,6 +79,10 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_map = parse_decls(extract_block(text, "DECLS")) except ValueError: decls_map = {} + try: + templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) + except ValueError: + templates_map = {} with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: common_decls = f.read() @@ -85,11 +97,15 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_code = "" for key in decls: if key not in decls_map: + raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + variant = replace_repl_placeholders(variant, templates_map) + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index 7f632a24e5..93f0fac66f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -1,342 +1,363 @@ +#define(REPL_TEMPLATES) + +{ + "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);", + "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);", + "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);", + "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];", + "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));", + "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", + "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", + "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", + "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", + "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);", + "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" +} + +#end(REPL_TEMPLATES) + #define(VARIANTS) [ - { - "SHADER_NAME": "abs_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "abs_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "step_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "step_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" + { + "SHADER_NAME": "abs_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" + { + "SHADER_NAME": "abs_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" + { + "SHADER_NAME": "abs_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" + { + "SHADER_NAME": "abs_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] }, - "DECLS": ["INPLACE"] - } + + { + "SHADER_NAME": "sgn_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "neg_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "step_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "tanh_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "elu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "elu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "relu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "sigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "silu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "exp_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardsigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardswish_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_quick_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "xielu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "xielu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + } ] #end(VARIANTS) @@ -346,9 +367,6 @@ #decl(INPLACE) @group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) var params: Params; #enddecl(INPLACE)