All operators (inlcluding xielu) working

This commit is contained in:
James Contini 2025-10-12 13:32:45 -07:00
parent 74c6add176
commit 4cf28d7dec
2 changed files with 287 additions and 106 deletions

View File

@ -262,7 +262,7 @@ struct webgpu_context_struct {
webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split
webgpu_pipeline scale_pipeline[2]; // inplace
webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
webgpu_pipeline unary_pipeline[16][2][2];
webgpu_pipeline unary_pipeline[GGML_UNARY_OP_COUNT][2][2];
size_t memset_bytes_per_thread;
@ -344,6 +344,8 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
pipeline_desc.compute.constants = constants.data();
pipeline_desc.compute.constantCount = constants.size();
}
pipeline = { device.CreateComputePipeline(&pipeline_desc), label };
}
@ -867,7 +869,8 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool in_place) {
bool in_place,
bool additional_params=false) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
@ -885,6 +888,11 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx,
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
};
if (additional_params) {
for (uint i = 1; i < 5; i++) {
params.push_back((uint32_t)(ggml_get_op_params_f32(dst, i))); // alpha_n, alpha_p, beta, eps
}
}
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -1302,8 +1310,10 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_UNARY:
{
const ggml_unary_op UNARY_OP = ggml_get_unary_op(node);
int in_place = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place);
bool XIELU = (UNARY_OP == GGML_UNARY_OP_XIELU);
return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place, XIELU);
}
default:
@ -2023,6 +2033,16 @@ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1],
wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants);
// XIELU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0],
wgsl_xielu_f32, "xielu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0],
wgsl_xielu_f16, "xielu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1],
wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1],
wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants);
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
@ -2254,9 +2274,36 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
supports_op = op->type == GGML_TYPE_F32;
break;
case GGML_OP_UNARY:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
{
const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
switch (UNARY_OP) {
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_XIELU:
supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
(src1 ? (src1->type == op->type) : true);
break;
case GGML_UNARY_OP_COUNT:
default:
break;
}
}
break;
default:
break;
}

View File

@ -4,22 +4,23 @@
{
"SHADER_NAME": "abs_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "abs_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "abs_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "abs_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sgn_f32",
@ -27,7 +28,7 @@
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sgn_f16",
@ -35,7 +36,7 @@
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sgn_in_place_f32",
@ -43,7 +44,7 @@
"TYPE": "f32",
"FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sgn_in_place_f16",
@ -51,27 +52,27 @@
"TYPE": "f16",
"FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "neg_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "neg_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "neg_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "neg_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "step_f32",
@ -79,7 +80,7 @@
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "step_f16",
@ -87,7 +88,7 @@
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "step_in_place_f32",
@ -95,7 +96,7 @@
"TYPE": "f32",
"FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "step_in_place_f16",
@ -103,27 +104,27 @@
"TYPE": "f16",
"FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "tanh_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "tanh_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "tanh_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "tanh_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "elu_f32",
@ -131,7 +132,7 @@
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "elu_f16",
@ -139,7 +140,7 @@
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "elu_in_place_f32",
@ -147,7 +148,7 @@
"TYPE": "f32",
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "elu_in_place_f16",
@ -155,7 +156,7 @@
"TYPE": "f16",
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "relu_f32",
@ -163,7 +164,7 @@
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "relu_f16",
@ -171,7 +172,7 @@
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "relu_in_place_f32",
@ -179,7 +180,7 @@
"TYPE": "f32",
"FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "relu_in_place_f16",
@ -187,179 +188,211 @@
"TYPE": "f16",
"FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sigmoid_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "sigmoid_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_quick_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_quick_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
"DECLS": ["NOT_INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_quick_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_quick_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i])));" },
"DECLS": ["INPLACE"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "silu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "silu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "silu_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "silu_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardswish_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardswish_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardswish_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardswish_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardsigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardsigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardsigmoid_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "hardsigmoid_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "exp_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "exp_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "exp_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "exp_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" },
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_erf_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
"FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_erf_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
"FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["NOT_INPLACE"]
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_erf_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
"FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "gelu_erf_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
"FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["INPLACE"]
"DECLS": ["INPLACE_DFLT_PARAMS"]
},
{
"SHADER_NAME": "xielu_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE_EXT_PARAMS"]
},
{
"SHADER_NAME": "xielu_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE_EXT_PARAMS"]
},
{
"SHADER_NAME": "xielu_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["INPLACE_EXT_PARAMS"]
},
{
"SHADER_NAME": "xielu_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE_EXT_PARAMS"]
}
]
@ -367,11 +400,7 @@
#define(DECLS)
#decl(NOT_INPLACE)
fn update(dst_i: u32, src_i: u32) {
{{FUNC}}
}
#decl(NOT_INPLACE_DFLT_PARAMS)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@ -379,26 +408,6 @@ var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(dst_i: u32, src_i: u32) {
{{FUNC}} // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458
}
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
@ -425,10 +434,135 @@ struct Params {
dst_ne2: u32
};
#enddecl(NOT_INPLACE_DFLT_PARAMS)
#decl(INPLACE_DFLT_PARAMS)
@group(0) @binding(1)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};
#enddecl(INPLACE_DFLT_PARAMS)
#decl(NOT_INPLACE_EXT_PARAMS)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
// XIELU params
alpha_n: u32,
alpha_p: u32,
beta: u32,
eps: u32
};
#enddecl(NOT_INPLACE_EXT_PARAMS)
#decl(INPLACE_EXT_PARAMS)
@group(0) @binding(1)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
// XIELU params
alpha_n: u32,
alpha_p: u32,
beta: u32,
eps: u32
};
#enddecl(INPLACE_EXT_PARAMS)
#end(DECLS)
#define(SHADER)
enable f16;
fn update(dst_i: u32, src_i: u32) {
{{FUNC}}
}
@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
DECLS
override wg_size: u32;