unary operators pass ggml tests
This commit is contained in:
parent
c3ae38278a
commit
8a6ec843a5
|
|
@ -144,8 +144,24 @@ struct webgpu_context_struct {
|
|||
wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
|
||||
wgpu::ComputePipeline scale_pipeline[2]; // inplace
|
||||
wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
|
||||
wgpu::ComputePipeline neg_pipeline;
|
||||
wgpu::ComputePipeline neg_ip_pipeline;
|
||||
wgpu::ComputePipeline unary_pipeline[16][2][2];
|
||||
|
||||
|
||||
/* wgpu::ComputePipeline abs_pipeline[2][2]; // abs
|
||||
wgpu::ComputePipeline sgn_pipeline[2][2]; // sgn
|
||||
wgpu::ComputePipeline neg_pipeline[2][2]; // neg
|
||||
wgpu::ComputePipeline step_pipeline[2][2]; // step
|
||||
wgpu::ComputePipeline tanh_pipeline[2][2]; // tanh
|
||||
wgpu::ComputePipeline elu_pipeline[2][2]; // elu
|
||||
wgpu::ComputePipeline relu_pipeline[2][2]; // relu
|
||||
wgpu::ComputePipeline sigmoid_pipeline[2][2]; // sigmoid
|
||||
wgpu::ComputePipeline gelu_pipeline[2][2]; // gelu
|
||||
wgpu::ComputePipeline gelu_quick_pipeline[2][2]; // gelu_quick
|
||||
wgpu::ComputePipeline silu_pipeline[2][2]; // silu (a.k.a. swish)
|
||||
wgpu::ComputePipeline hardswish_pipeline[2][2]; // hardswish
|
||||
wgpu::ComputePipeline hardsigmoid_pipeline[2][2]; // hardsigmoid
|
||||
wgpu::ComputePipeline exp_pipeline[2][2]; // exp
|
||||
wgpu::ComputePipeline gelu_erf_pipeline[2][2]; // gelu_erf */
|
||||
|
||||
size_t memset_bytes_per_thread;
|
||||
|
||||
|
|
@ -250,6 +266,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
|||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
||||
|
||||
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
||||
if (ctx->callback_futures.empty()) {
|
||||
// no existing callbacks, wait on queue submission
|
||||
|
|
@ -274,6 +291,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
|
|||
}
|
||||
|
||||
static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
|
||||
|
||||
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()");
|
||||
if (ctx->staged_command_bufs.empty()) {
|
||||
|
|
@ -373,6 +391,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
|
|||
uint32_t wg_x,
|
||||
const char * bind_group_label = nullptr,
|
||||
bool submit_and_wait = false) {
|
||||
|
||||
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
|
||||
|
|
@ -491,39 +510,6 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
|||
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Logical shapes
|
||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
|
||||
ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
|
||||
// For set rows specifically, we need to check if src and idx are empty tensors.
|
||||
if (ggml_is_empty(src) || ggml_is_empty(idx)) {
|
||||
|
|
@ -659,6 +645,83 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
|
|||
ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Logical shapes
|
||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
|
||||
ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_unary_op( webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
wgpu::ComputePipeline & pipeline,
|
||||
bool in_place) {
|
||||
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
// Convert byte-strides to element-strides
|
||||
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
|
||||
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
// Logical shapes
|
||||
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
|
||||
};
|
||||
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
|
||||
};
|
||||
if (!in_place) {
|
||||
entries.push_back({ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
|
||||
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
|
||||
}
|
||||
|
||||
static void ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
|
|
@ -994,38 +1057,12 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx,
|
|||
ggml_nrows(dst), ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_neg( webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * dst,
|
||||
wgpu::ComputePipeline & pipeline,
|
||||
bool in_place) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) ggml_nelements(dst)
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src) },
|
||||
|
||||
};
|
||||
if (!in_place) {
|
||||
entries.push_back({ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
size_t max_wg_size = ctx->max_wg_size_x;
|
||||
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
|
||||
|
||||
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
|
||||
// Returns true if node has enqueued work into the queue, false otherwise
|
||||
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
||||
|
||||
if (ggml_is_empty(node)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1035,6 +1072,8 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|||
ggml_tensor * src1 = node->src[1];
|
||||
ggml_tensor * src2 = node->src[2];
|
||||
|
||||
|
||||
|
||||
switch (node->op) {
|
||||
// no-ops
|
||||
case GGML_OP_NONE:
|
||||
|
|
@ -1092,29 +1131,23 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|||
case GGML_OP_SCALE:
|
||||
ggml_webgpu_scale(ctx, src0, node);
|
||||
break;
|
||||
case GGML_OP_UNARY: {
|
||||
// if unary, switch on unary operators
|
||||
const ggml_unary_op unary_op = ggml_get_unary_op(node);
|
||||
switch (unary_op) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_neg(ctx, src0, node, ctx->neg_ip_pipeline, true);
|
||||
} else {
|
||||
ggml_webgpu_neg(ctx, src0, src1, ctx->neg_pipeline, false);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
case GGML_OP_UNARY:
|
||||
{
|
||||
const ggml_unary_op UNARY_OP = ggml_get_unary_op(node);
|
||||
int in_place = ggml_webgpu_tensor_equal(src0, node);
|
||||
ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place);
|
||||
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||
|
||||
ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
|
||||
|
|
@ -1296,6 +1329,8 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer
|
|||
|
||||
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
|
||||
size_t size) {
|
||||
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
|
||||
|
||||
|
|
@ -1307,6 +1342,8 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
|
|||
|
||||
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
|
||||
|
||||
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
|
||||
}
|
||||
|
||||
|
|
@ -1670,19 +1707,162 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
|
|||
constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_neg_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f32, "neg_f32",
|
||||
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f16, "neg_f16",
|
||||
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f32, "neg_in_place_f32",
|
||||
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f16, "neg_in_place_f16",
|
||||
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
|
||||
|
||||
static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
|
||||
|
||||
// ABS
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0],
|
||||
wgsl_abs_f32, "abs_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0],
|
||||
wgsl_abs_f16, "abs_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1],
|
||||
wgsl_abs_in_place_f32, "abs_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1],
|
||||
wgsl_abs_in_place_f16, "abs_in_place_f16", constants);
|
||||
|
||||
// SGN
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0],
|
||||
wgsl_sgn_f32, "sgn_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0],
|
||||
wgsl_sgn_f16, "sgn_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1],
|
||||
wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1],
|
||||
wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants);
|
||||
|
||||
// NEG
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0],
|
||||
wgsl_neg_f32, "neg_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0],
|
||||
wgsl_neg_f16, "neg_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1],
|
||||
wgsl_neg_in_place_f32, "neg_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1],
|
||||
wgsl_neg_in_place_f16, "neg_in_place_f16", constants);
|
||||
|
||||
// STEP
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0],
|
||||
wgsl_step_f32, "step_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0],
|
||||
wgsl_step_f16, "step_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1],
|
||||
wgsl_step_in_place_f32, "step_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1],
|
||||
wgsl_step_in_place_f16, "step_in_place_f16", constants);
|
||||
|
||||
// TANH
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0],
|
||||
wgsl_tanh_f32, "tanh_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0],
|
||||
wgsl_tanh_f16, "tanh_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1],
|
||||
wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1],
|
||||
wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants);
|
||||
|
||||
// ELU
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0],
|
||||
wgsl_elu_f32, "elu_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0],
|
||||
wgsl_elu_f16, "elu_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1],
|
||||
wgsl_elu_in_place_f32, "elu_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1],
|
||||
wgsl_elu_in_place_f16, "elu_in_place_f16", constants);
|
||||
|
||||
// RELU
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0],
|
||||
wgsl_relu_f32, "relu_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0],
|
||||
wgsl_relu_f16, "relu_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1],
|
||||
wgsl_relu_in_place_f32, "relu_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1],
|
||||
wgsl_relu_in_place_f16, "relu_in_place_f16", constants);
|
||||
|
||||
// SIGMOID
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0],
|
||||
wgsl_sigmoid_f32, "sigmoid_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0],
|
||||
wgsl_sigmoid_f16, "sigmoid_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1],
|
||||
wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1],
|
||||
wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants);
|
||||
|
||||
// GELU
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0],
|
||||
wgsl_gelu_f32, "gelu_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0],
|
||||
wgsl_gelu_f16, "gelu_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1],
|
||||
wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1],
|
||||
wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants);
|
||||
|
||||
// GELU_QUICK
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0],
|
||||
wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0],
|
||||
wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1],
|
||||
wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1],
|
||||
wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants);
|
||||
|
||||
// SILU
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0],
|
||||
wgsl_silu_f32, "silu_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0],
|
||||
wgsl_silu_f16, "silu_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1],
|
||||
wgsl_silu_in_place_f32, "silu_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1],
|
||||
wgsl_silu_in_place_f16, "silu_in_place_f16", constants);
|
||||
|
||||
// HARDSWISH
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0],
|
||||
wgsl_hardswish_f32, "hardswish_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0],
|
||||
wgsl_hardswish_f16, "hardswish_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1],
|
||||
wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1],
|
||||
wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants);
|
||||
|
||||
// HARDSIGMOID
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0],
|
||||
wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0],
|
||||
wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1],
|
||||
wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1],
|
||||
wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants);
|
||||
|
||||
// EXP
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0],
|
||||
wgsl_exp_f32, "exp_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0],
|
||||
wgsl_exp_f16, "exp_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1],
|
||||
wgsl_exp_in_place_f32, "exp_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1],
|
||||
wgsl_exp_in_place_f16, "exp_in_place_f16", constants);
|
||||
|
||||
// GELU_ERF
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0],
|
||||
wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0],
|
||||
wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1],
|
||||
wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1],
|
||||
wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants);
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
|
||||
GGML_UNUSED(params);
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
|
||||
|
|
@ -1701,12 +1881,13 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
|
|||
/* .device = */ dev,
|
||||
/* .context = */ &backend_ctx,
|
||||
};
|
||||
|
||||
//tried
|
||||
return &backend;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
// See GGML Backend Buffer Type Interface section
|
||||
|
||||
static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
|
||||
|
|
@ -1757,6 +1938,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
|||
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
|
||||
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
|
||||
|
||||
webgpu_context webgpu_ctx = ctx->webgpu_ctx;
|
||||
|
|
@ -1866,6 +2048,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_OP_SCALE:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
|
||||
(src1 ? (src1->type == op->type) : true);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -1888,6 +2074,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
|
||||
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
|
||||
}
|
||||
|
||||
|
||||
return supports_op;
|
||||
}
|
||||
|
||||
|
|
@ -1929,6 +2117,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
GGML_ASSERT(index == 0);
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
|
||||
|
||||
|
||||
|
||||
ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
|
||||
|
||||
webgpu_context ctx = reg_ctx->webgpu_ctx;
|
||||
|
|
@ -1996,6 +2186,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
|
||||
|
||||
ggml_webgpu_init_memset_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(ctx);
|
||||
|
|
@ -2009,6 +2200,24 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
ggml_webgpu_init_rope_pipeline(ctx);
|
||||
ggml_webgpu_init_glu_pipeline(ctx);
|
||||
ggml_webgpu_init_scale_pipeline(ctx);
|
||||
ggml_webgpu_init_unary_pipeline(ctx);
|
||||
|
||||
|
||||
/* ggml_webgpu_init_abs_pipeline(ctx);
|
||||
ggml_webgpu_init_sgn_pipeline(ctx);
|
||||
ggml_webgpu_init_neg_pipeline(ctx);
|
||||
ggml_webgpu_init_step_pipeline(ctx);
|
||||
ggml_webgpu_init_tanh_pipeline(ctx);
|
||||
ggml_webgpu_init_elu_pipeline(ctx);
|
||||
ggml_webgpu_init_relu_pipeline(ctx);
|
||||
ggml_webgpu_init_sigmoid_pipeline(ctx);
|
||||
ggml_webgpu_init_gelu_pipeline(ctx);
|
||||
ggml_webgpu_init_gelu_quick_pipeline(ctx);
|
||||
ggml_webgpu_init_silu_pipeline(ctx);
|
||||
ggml_webgpu_init_hardswish_pipeline(ctx);
|
||||
ggml_webgpu_init_hardsigmoid_pipeline(ctx);
|
||||
ggml_webgpu_init_exp_pipeline(ctx);
|
||||
ggml_webgpu_init_gelu_erf_pipeline(ctx); */
|
||||
|
||||
#ifdef GGML_WEBGPU_DEBUG
|
||||
// Initialize debug buffers
|
||||
|
|
@ -2035,6 +2244,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
/* .reg = */ reg,
|
||||
/* .context = */ &device_ctx,
|
||||
};
|
||||
|
||||
|
||||
return &device;
|
||||
}
|
||||
|
||||
|
|
@ -2048,6 +2259,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
|
|||
/* End GGML Backend Registration Interface */
|
||||
|
||||
ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
|
||||
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
|
||||
|
||||
webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
|
||||
|
|
@ -2073,8 +2285,9 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
|||
}
|
||||
|
||||
ggml_backend_t ggml_backend_webgpu_init(void) {
|
||||
|
||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
|
||||
|
||||
|
||||
return ggml_backend_webgpu_device_init(dev, nullptr);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,467 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"SHADER_NAME": "abs_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "abs_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "abs_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "abs_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sgn_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sgn_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sgn_in_place_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sgn_in_place_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "neg_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "neg_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "neg_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "neg_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "step_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "step_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "step_in_place_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "step_in_place_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "tanh_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "tanh_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "tanh_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "tanh_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "elu_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "elu_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "elu_in_place_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "elu_in_place_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "relu_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "relu_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "relu_in_place_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "relu_in_place_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sigmoid_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sigmoid_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sigmoid_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sigmoid_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_quick_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_quick_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_quick_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_quick_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i])));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "silu_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "silu_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "silu_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "silu_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardswish_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardswish_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardswish_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardswish_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardsigmoid_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardsigmoid_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardsigmoid_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "hardsigmoid_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "exp_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "exp_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" },
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "exp_in_place_f32",
|
||||
"REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "exp_in_place_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" },
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_erf_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_erf_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_erf_in_place_f32",
|
||||
"REPLS": {
|
||||
"TYPE": "f32",
|
||||
"FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "gelu_erf_in_place_f16",
|
||||
"REPLS": {
|
||||
"TYPE": "f16",
|
||||
"FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
|
||||
fn update(dst_i: u32, src_i: u32) {
|
||||
{{FUNC}}
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#decl(INPLACE)
|
||||
|
||||
fn update(dst_i: u32, src_i: u32) {
|
||||
{{FUNC}} // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458
|
||||
}
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
struct Params {
|
||||
ne: u32, // total number of elements
|
||||
offset_src: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements) — may be permuted
|
||||
stride_src0: u32,
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Logical shapes
|
||||
src_ne0: u32,
|
||||
src_ne1: u32,
|
||||
src_ne2: u32,
|
||||
|
||||
dst_ne0: u32,
|
||||
dst_ne1: u32,
|
||||
dst_ne2: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<{{TYPE}}>;
|
||||
|
||||
|
||||
DECLS
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
let i2 = i / (params.src_ne1 * params.src_ne0);
|
||||
i = i % (params.src_ne1 * params.src_ne0);
|
||||
let i1 = i / params.src_ne0;
|
||||
let i0 = i % params.src_ne0;
|
||||
|
||||
var j = gid.x;
|
||||
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
let j2 = j / (params.dst_ne1 * params.dst_ne0);
|
||||
j = j % (params.dst_ne1 * params.dst_ne0);
|
||||
let j1 = j / params.dst_ne0;
|
||||
let j0 = j % params.dst_ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
||||
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
|
||||
j2 * params.stride_dst2 + j3 * params.stride_dst3;
|
||||
|
||||
|
||||
update(params.offset_dst + dst_idx, params.offset_src + src_idx);
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
Loading…
Reference in New Issue