responded and dealt with PR comments

This commit is contained in:
James Contini 2025-10-15 16:14:20 -07:00
parent f9282c660c
commit 8c70b8fece
2 changed files with 296 additions and 450 deletions

View File

@ -62,7 +62,6 @@
#define WEBGPU_MUL_MAT_WG_SIZE 256
#define WEBGPU_NUM_PARAM_BUFS 32u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
@ -252,19 +251,18 @@ struct webgpu_context_struct {
webgpu_pipeline set_rows_pipeline;
webgpu_pipeline get_rows_pipeline[30];
webgpu_pipeline get_rows_f32_no_vec_pipeline;
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
webgpu_pipeline add_pipeline[2][2]; // type, inplace
webgpu_pipeline sub_pipeline[2][2]; // type, inplace
webgpu_pipeline mul_pipeline[2][2]; // type, inplace
webgpu_pipeline div_pipeline[2][2]; // type, inplace
webgpu_pipeline rms_norm_pipeline[2]; // inplace
webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace
webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split
webgpu_pipeline scale_pipeline[2]; // inplace
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
webgpu_pipeline add_pipeline[2][2]; // type, inplace
webgpu_pipeline sub_pipeline[2][2]; // type, inplace
webgpu_pipeline mul_pipeline[2][2]; // type, inplace
webgpu_pipeline div_pipeline[2][2]; // type, inplace
webgpu_pipeline rms_norm_pipeline[2]; // inplace
webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace
webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split
webgpu_pipeline scale_pipeline[2]; // inplace
webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
webgpu_pipeline unary_pipeline[GGML_UNARY_OP_COUNT][2][2];
size_t memset_bytes_per_thread;
// Staging buffer for reading data from the GPU
@ -345,7 +343,6 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
pipeline_desc.compute.constantCount = constants.size();
}
pipeline = { device.CreateComputePipeline(&pipeline_desc), label };
}
@ -721,10 +718,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
size_t max_wg_size = ctx->max_wg_size_x;
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x);
}
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
@ -865,14 +860,12 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
}
static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool in_place,
bool additional_params=false) {
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool in_place,
const std::vector<uint32_t> & xielu_params = {}) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
@ -888,18 +881,13 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx,
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
};
if (additional_params) {
for (uint i = 1; i < 5; i++) {
params.push_back((uint32_t)(ggml_get_op_params_f32(dst, i))); // alpha_n, alpha_p, beta, eps
}
}
params.insert(params.end(), xielu_params.begin(), xielu_params.end());
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src),
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
};
if (!in_place) {
entries.push_back({ .binding = 1,
@ -1258,8 +1246,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
ggml_tensor * src1 = node->src[1];
ggml_tensor * src2 = node->src[2];
switch (node->op) {
// no-ops
case GGML_OP_NONE:
@ -1309,11 +1295,24 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
case GGML_OP_UNARY:
{
const ggml_unary_op UNARY_OP = ggml_get_unary_op(node);
const ggml_unary_op UNARY_OP = ggml_get_unary_op(node);
int in_place = ggml_webgpu_tensor_equal(src0, node);
std::vector<uint32_t> xielu_params;
int in_place = ggml_webgpu_tensor_equal(src0, node);
bool XIELU = (UNARY_OP == GGML_UNARY_OP_XIELU);
return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place, XIELU);
switch (UNARY_OP) {
case GGML_UNARY_OP_XIELU:
xielu_params = {
static_cast<uint32_t>(ggml_get_op_params_f32(node, 1)), // alpha_n
static_cast<uint32_t>(ggml_get_op_params_f32(node, 2)), // alpha_p
static_cast<uint32_t>(ggml_get_op_params_f32(node, 3)), // beta
static_cast<uint32_t>(ggml_get_op_params_f32(node, 4)) // eps
};
break;
default:
break;
}
return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place],
in_place, xielu_params);
}
default:
@ -1322,7 +1321,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
}
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
@ -1541,8 +1539,6 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
@ -1554,8 +1550,6 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
}
@ -1886,163 +1880,179 @@ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
// ABS
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0],
wgsl_abs_f32, "abs_f32", constants);
wgsl_abs_f32, "abs_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0],
wgsl_abs_f16, "abs_f16", constants);
wgsl_abs_f16, "abs_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1],
wgsl_abs_in_place_f32, "abs_in_place_f32", constants);
wgsl_abs_in_place_f32, "abs_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1],
wgsl_abs_in_place_f16, "abs_in_place_f16", constants);
wgsl_abs_in_place_f16, "abs_in_place_f16", constants);
// SGN
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0],
wgsl_sgn_f32, "sgn_f32", constants);
wgsl_sgn_f32, "sgn_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0],
wgsl_sgn_f16, "sgn_f16", constants);
wgsl_sgn_f16, "sgn_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1],
wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants);
wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1],
wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants);
wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants);
// NEG
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0],
wgsl_neg_f32, "neg_f32", constants);
wgsl_neg_f32, "neg_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0],
wgsl_neg_f16, "neg_f16", constants);
wgsl_neg_f16, "neg_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1],
wgsl_neg_in_place_f32, "neg_in_place_f32", constants);
wgsl_neg_in_place_f32, "neg_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1],
wgsl_neg_in_place_f16, "neg_in_place_f16", constants);
wgsl_neg_in_place_f16, "neg_in_place_f16", constants);
// STEP
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0],
wgsl_step_f32, "step_f32", constants);
wgsl_step_f32, "step_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0],
wgsl_step_f16, "step_f16", constants);
wgsl_step_f16, "step_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1],
wgsl_step_in_place_f32, "step_in_place_f32", constants);
wgsl_step_in_place_f32, "step_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1],
wgsl_step_in_place_f16, "step_in_place_f16", constants);
wgsl_step_in_place_f16, "step_in_place_f16", constants);
// TANH
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0],
wgsl_tanh_f32, "tanh_f32", constants);
wgsl_tanh_f32, "tanh_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0],
wgsl_tanh_f16, "tanh_f16", constants);
wgsl_tanh_f16, "tanh_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1],
wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants);
wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1],
wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants);
wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants);
// ELU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0],
wgsl_elu_f32, "elu_f32", constants);
wgsl_elu_f32, "elu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0],
wgsl_elu_f16, "elu_f16", constants);
wgsl_elu_f16, "elu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1],
wgsl_elu_in_place_f32, "elu_in_place_f32", constants);
wgsl_elu_in_place_f32, "elu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1],
wgsl_elu_in_place_f16, "elu_in_place_f16", constants);
wgsl_elu_in_place_f16, "elu_in_place_f16", constants);
// RELU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0],
wgsl_relu_f32, "relu_f32", constants);
wgsl_relu_f32, "relu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0],
wgsl_relu_f16, "relu_f16", constants);
wgsl_relu_f16, "relu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1],
wgsl_relu_in_place_f32, "relu_in_place_f32", constants);
wgsl_relu_in_place_f32, "relu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1],
wgsl_relu_in_place_f16, "relu_in_place_f16", constants);
wgsl_relu_in_place_f16, "relu_in_place_f16", constants);
// SIGMOID
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0],
wgsl_sigmoid_f32, "sigmoid_f32", constants);
wgsl_sigmoid_f32, "sigmoid_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0],
wgsl_sigmoid_f16, "sigmoid_f16", constants);
wgsl_sigmoid_f16, "sigmoid_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1],
wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants);
wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1],
wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants);
wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants);
// GELU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0],
wgsl_gelu_f32, "gelu_f32", constants);
wgsl_gelu_f32, "gelu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0],
wgsl_gelu_f16, "gelu_f16", constants);
wgsl_gelu_f16, "gelu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1],
wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants);
wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1],
wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants);
wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants);
// GELU_QUICK
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0],
wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0],
wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1],
wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1],
wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0],
wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0],
wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1],
wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1],
wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants);
// SILU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0],
wgsl_silu_f32, "silu_f32", constants);
wgsl_silu_f32, "silu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0],
wgsl_silu_f16, "silu_f16", constants);
wgsl_silu_f16, "silu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1],
wgsl_silu_in_place_f32, "silu_in_place_f32", constants);
wgsl_silu_in_place_f32, "silu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1],
wgsl_silu_in_place_f16, "silu_in_place_f16", constants);
wgsl_silu_in_place_f16, "silu_in_place_f16", constants);
// HARDSWISH
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0],
wgsl_hardswish_f32, "hardswish_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0],
wgsl_hardswish_f16, "hardswish_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1],
wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1],
wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0],
wgsl_hardswish_f32, "hardswish_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0],
wgsl_hardswish_f16, "hardswish_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1],
wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1],
wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants);
// HARDSIGMOID
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0],
wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0],
wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1],
wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1],
wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0],
wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0],
wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1],
wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1],
wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants);
// EXP
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0],
wgsl_exp_f32, "exp_f32", constants);
wgsl_exp_f32, "exp_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0],
wgsl_exp_f16, "exp_f16", constants);
wgsl_exp_f16, "exp_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1],
wgsl_exp_in_place_f32, "exp_in_place_f32", constants);
wgsl_exp_in_place_f32, "exp_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1],
wgsl_exp_in_place_f16, "exp_in_place_f16", constants);
wgsl_exp_in_place_f16, "exp_in_place_f16", constants);
// GELU_ERF
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0],
wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0],
wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1],
wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1],
wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], wgsl_gelu_erf_f32,
"gelu_erf_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], wgsl_gelu_erf_f16,
"gelu_erf_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1],
wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1],
wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants);
// XIELU
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0],
wgsl_xielu_f32, "xielu_f32", constants);
wgsl_xielu_f32, "xielu_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0],
wgsl_xielu_f16, "xielu_f16", constants);
wgsl_xielu_f16, "xielu_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1],
wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants);
wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1],
wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants);
wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants);
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
@ -2084,7 +2094,6 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
@ -2103,7 +2112,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
/* .device = */ dev,
/* .context = */ &backend_ctx,
};
//tried
return &backend;
}
@ -2160,7 +2168,6 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
}
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
@ -2294,9 +2301,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_XIELU:
supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
supports_op = supports_op =
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_UNARY_OP_COUNT:
default:
break;
}
@ -2448,7 +2455,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
ggml_webgpu_init_memset_pipeline(ctx);
ggml_webgpu_init_mul_mat_pipeline(ctx);
ggml_webgpu_init_set_rows_pipeline(ctx);
@ -2505,7 +2511,6 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
/* End GGML Backend Registration Interface */
ggml_backend_reg_t ggml_backend_webgpu_reg() {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
@ -2531,7 +2536,6 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
}
ggml_backend_t ggml_backend_webgpu_init(void) {
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
return ggml_backend_webgpu_device_init(dev, nullptr);

View File

@ -3,396 +3,339 @@
[
{
"SHADER_NAME": "abs_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "abs_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "abs_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sgn_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sgn_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "neg_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "neg_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "step_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "step_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "tanh_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "tanh_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "elu_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "elu_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "relu_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "relu_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sigmoid_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sigmoid_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_quick_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "silu_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "silu_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardswish_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardswish_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "hardsigmoid_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_f32",
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_f16",
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "exp_in_place_f32",
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "exp_in_place_f16",
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" },
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["NOT_INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "gelu_erf_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
},
"DECLS": ["INPLACE_DFLT_PARAMS"]
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" },
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);"
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);",
"EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32"
},
"DECLS": ["NOT_INPLACE_EXT_PARAMS"]
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);"
"FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);",
"EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32"
},
"DECLS": ["NOT_INPLACE_EXT_PARAMS"]
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "xielu_in_place_f32",
"REPLS": {
"TYPE": "f32",
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);"
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);",
"EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32"
},
"DECLS": ["INPLACE_EXT_PARAMS"]
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "xielu_in_place_f16",
"REPLS": {
"TYPE": "f16",
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);"
"FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);",
"EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32"
},
"DECLS": ["INPLACE_EXT_PARAMS"]
"DECLS": ["INPLACE"]
}
]
@ -400,7 +343,7 @@
#define(DECLS)
#decl(NOT_INPLACE_DFLT_PARAMS)
#decl(INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@ -408,68 +351,9 @@ var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
#enddecl(INPLACE)
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};
#enddecl(NOT_INPLACE_DFLT_PARAMS)
#decl(INPLACE_DFLT_PARAMS)
@group(0) @binding(1)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32
};
#enddecl(INPLACE_DFLT_PARAMS)
#decl(NOT_INPLACE_EXT_PARAMS)
#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@ -477,78 +361,7 @@ var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
// XIELU params
alpha_n: u32,
alpha_p: u32,
beta: u32,
eps: u32
};
#enddecl(NOT_INPLACE_EXT_PARAMS)
#decl(INPLACE_EXT_PARAMS)
@group(0) @binding(1)
var<uniform> params: Params;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
// XIELU params
alpha_n: u32,
alpha_p: u32,
beta: u32,
eps: u32
};
#enddecl(INPLACE_EXT_PARAMS)
#enddecl(NOT_INPLACE)
#end(DECLS)
@ -565,6 +378,34 @@ var<storage, read_write> src: array<{{TYPE}}>;
DECLS
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
dst_ne0: u32,
dst_ne1: u32,
dst_ne2: u32,
{{EXT_PARAMS}}
};
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
@ -599,3 +440,4 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
}
#end(SHADER)