Merge pull request #1 from reeselevine/addition

F32 Support for Addition Operation
This commit is contained in:
Reese Levine 2025-09-09 15:42:27 -07:00 committed by GitHub
commit efc0cb0c28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 503 additions and 209 deletions

View File

@ -125,6 +125,10 @@ struct webgpu_context_struct {
wgpu::ComputePipeline mul_mat_pipeline[30][2];
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline cpy_pipeline;
wgpu::ComputePipeline add_pipeline[2];
wgpu::ComputePipeline add_ip_pipeline[2];
wgpu::ComputePipeline mul_pipeline[2];
wgpu::ComputePipeline mul_ip_pipeline[2];
size_t memset_bytes_per_thread;
@ -232,14 +236,15 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
if (ctx->callback_futures.empty()) {
// no existing callbacks, wait on queue submission
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
}
}),
UINT64_MAX);
ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
@ -286,10 +291,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
// Check for errrors in SET_ROWS operations
for (auto & error_bufs : staged_set_row_error_bufs) {
wgpu::Future f = error_bufs.host_buf.MapAsync(
wgpu::MapMode::Read,
0,
error_bufs.host_buf.GetSize(),
wgpu::CallbackMode::AllowSpontaneous,
wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
@ -311,10 +313,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
wgpu::MapMode mode,
size_t offset,
size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(mode,
offset,
size,
wgpu::CallbackMode::AllowSpontaneous,
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
@ -351,7 +350,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_and_wait = false) {
const char * bind_group_label = nullptr,
bool submit_and_wait = false) {
webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
@ -372,6 +372,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
bind_group_desc.entryCount = bind_group_entries.size();
bind_group_desc.entries = bind_group_entries.data();
if (bind_group_label) {
bind_group_desc.label = bind_group_label;
}
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
@ -417,7 +420,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
};
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
}
/** End WebGPU Actions */
@ -461,26 +464,26 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
}
// Used to determine if two tensors are the same for in-place operations
static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3] };
std::vector<uint32_t> params = {
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -495,7 +498,7 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@ -509,27 +512,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
error_bufs.host_buf.Unmap();
}
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -553,7 +550,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
ctx->staged_set_row_error_bufs.push_back(error_bufs);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@ -593,7 +590,54 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
uint32_t wg_x =
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
ggml_op_name(dst->op));
}
static void ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
wgpu::ComputePipeline & pipeline,
bool in_place) {
std::vector<uint32_t> params = {
(uint32_t) ggml_nelements(dst),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
if (!in_place) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
}
// Returns true if node has enqueued work into the queue, false otherwise
@ -629,6 +673,24 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_webgpu_mul_mat(ctx, src0, src1, node);
break;
}
case GGML_OP_ADD:
{
if (ggml_webgpu_tensor_equal(src0, node)) {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
} else {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
}
break;
}
case GGML_OP_MUL:
{
if (ggml_webgpu_tensor_equal(src0, node)) {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
} else {
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
}
break;
}
default:
return false;
}
@ -730,8 +792,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
}
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
remaining_size);
} else {
// wait for WriteBuffer to complete
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
@ -765,11 +827,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy();
}
ggml_webgpu_create_buffer(device,
webgpu_ctx->get_tensor_staging_buf,
final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"get_tensor_staging_buf");
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
}
// Copy the data from the buffer to the staging buffer
@ -823,8 +882,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
buf,
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
"allocated_buffer");
@ -905,102 +963,58 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32,
"mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16,
"mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32,
"mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32,
"mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32,
"mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32,
"mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32,
"mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32,
"mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32,
"mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32,
"mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32,
"mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32,
"mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32,
"mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32,
"mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32,
"mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32,
"mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32,
"mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32,
"mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32,
"mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32,
"mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32,
"mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32,
"mul_mat_iq4_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
@ -1010,6 +1024,34 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
"add_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
"add_in_place_f16", constants);
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
"mul_in_place_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
"mul_in_place_f16", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@ -1060,21 +1102,30 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_UNUSED(dev);
bool supports_op = false;
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RESHAPE:
return true;
supports_op = true;
break;
case GGML_OP_ADD:
case GGML_OP_MUL:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
(op->src[1]->type == op->type);
break;
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
break;
case GGML_OP_MUL_MAT:
{
switch (op->src[1]->type) {
case GGML_TYPE_F16:
return op->src[0]->type == GGML_TYPE_F16;
supports_op = (op->src[0]->type == GGML_TYPE_F16);
break;
case GGML_TYPE_F32:
switch (op->src[0]->type) {
case GGML_TYPE_F32:
@ -1098,17 +1149,26 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
return true;
supports_op = true;
break;
default:
return false;
break;
}
default:
return false;
break;
}
}
default:
return false;
break;
}
#ifdef GGML_WEBGPU_DEBUG
if (!supports_op) {
WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
<< ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
<< ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
}
#endif
return supports_op;
}
static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
@ -1154,14 +1214,14 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {};
auto callback =
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message,
void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
};
void * userdata = &ctx->adapter;
ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
@ -1183,21 +1243,21 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
&dev_desc,
wgpu::CallbackMode::AllowSpontaneous,
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
std::string(message).c_str());
return;
}
ctx->device = std::move(device);
@ -1209,14 +1269,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ctx->queue = ctx->device.GetQueue();
// Create buffer pool for shader parameters
ctx->param_buf_pool.init(ctx->device,
WEBGPU_NUM_PARAM_BUFS,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
ctx->set_rows_error_buf_pool.init(ctx->device,
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
@ -1224,19 +1280,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ggml_webgpu_init_mul_mat_pipeline(ctx);
ggml_webgpu_init_set_rows_pipeline(ctx);
ggml_webgpu_init_cpy_pipeline(ctx);
ggml_webgpu_init_add_pipeline(ctx);
ggml_webgpu_init_mul_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_host_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"debug_host_buf");
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_dev_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
"debug_dev_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
#endif
static ggml_backend_webgpu_device_context device_ctx;
@ -1247,12 +1299,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
GGML_LOG_INFO(
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
"device_desc: %s\n",
info.vendorID,
std::string(info.vendor).c_str(),
std::string(info.architecture).c_str(),
info.deviceID,
std::string(info.device).c_str(),
std::string(info.description).c_str());
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
std::string(info.device).c_str(), std::string(info.description).c_str());
// See GGML Backend Device Interface section
static ggml_backend_device device = {

View File

@ -0,0 +1,44 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)

View File

@ -0,0 +1,41 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)

View File

@ -0,0 +1,45 @@
struct Params {
ne: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
return b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
b_i2 * params.stride_src1_2 +
b_i3 * params.stride_src1_3;
}

View File

@ -26,6 +26,24 @@ def replace_placeholders(shader_text, replacements):
shader_text = re.sub(pattern, str(val), shader_text)
return shader_text
def expand_includes(shader, input_dir):
"""
Replace #include "file" lines in the text with the contents of that file.
Searches for files relative to input_dir.
"""
include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
def replacer(match):
fname = match.group(1)
file_path = os.path.join(input_dir, fname)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Included file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
included_code = f.read()
# Recursively expand includes inside the included file
return expand_includes(included_code, input_dir)
return include_pattern.sub(replacer, shader)
def write_shader(shader_name, shader_code, output_dir, outfile):
if output_dir:
@ -35,8 +53,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile):
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
def generate_variants(shader_path, output_dir, outfile):
shader_base_name = shader_path.split("/")[-1].split(".")[0]
def generate_variants(fname, input_dir, output_dir, outfile):
shader_path = os.path.join(input_dir, fname)
shader_base_name = fname.split(".")[0]
with open(shader_path, "r", encoding="utf-8") as f:
text = f.read()
@ -46,11 +65,18 @@ def generate_variants(shader_path, output_dir, outfile):
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
else:
decls_map = parse_decls(extract_block(text, "DECLS"))
shader_template = extract_block(text, "SHADER")
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
except ValueError:
decls_map = {}
shader_template = extract_block(text, "SHADER")
shader_template = expand_includes(shader_template, input_dir)
for variant in variants:
decls = variant["DECLS"]
if "DECLS" in variant:
decls = variant["DECLS"]
else:
decls = []
decls_code = ""
for key in decls:
if key not in decls_map:
@ -60,7 +86,12 @@ def generate_variants(shader_path, output_dir, outfile):
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
elif "TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)
@ -78,7 +109,7 @@ def main():
out.write("// Auto-generated shader embedding\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
generate_variants(fname, args.input_dir, args.output_dir, out)
if __name__ == "__main__":

View File

@ -0,0 +1,44 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)

View File

@ -0,0 +1,41 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
}
}
#end(SHADER)