Add templated addition, clean up code

This commit is contained in:
Reese Levine 2025-09-04 14:12:44 -07:00
parent 1b16a91183
commit 7f9ee10e75
3 changed files with 179 additions and 282 deletions

View File

@ -125,7 +125,7 @@ struct webgpu_context_struct {
wgpu::ComputePipeline mul_mat_pipeline[30][2];
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline cpy_pipeline;
wgpu::ComputePipeline add_pipeline;
wgpu::ComputePipeline add_pipeline[2];
size_t memset_bytes_per_thread;
@ -233,14 +233,15 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
if (ctx->callback_futures.empty()) {
// no existing callbacks, wait on queue submission
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
}
}),
UINT64_MAX);
ctx->instance.WaitAny(
ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
} else {
// existing callbacks, wait on them
ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
@ -287,10 +288,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
// Check for errrors in SET_ROWS operations
for (auto & error_bufs : staged_set_row_error_bufs) {
wgpu::Future f = error_bufs.host_buf.MapAsync(
wgpu::MapMode::Read,
0,
error_bufs.host_buf.GetSize(),
wgpu::CallbackMode::AllowSpontaneous,
wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
[ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
@ -312,10 +310,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
wgpu::MapMode mode,
size_t offset,
size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(mode,
offset,
size,
wgpu::CallbackMode::AllowSpontaneous,
ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
@ -465,23 +460,17 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)),
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3] };
std::vector<uint32_t> params = {
ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -510,27 +499,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
error_bufs.host_buf.Unmap();
}
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0],
(uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]),
(uint32_t) (idx->ne[2]) };
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides
(uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
(uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Shape of src
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
// Shape of idx
(uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
@ -598,83 +581,44 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
}
static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
size_t src0_offset = ggml_webgpu_tensor_offset(src0);
// assumes power of 2 offset alignment
size_t src0_misalignment = src0_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src0_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t src1_offset = ggml_webgpu_tensor_offset(src1);
size_t src1_misalignment = src1_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
src1_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t dst_offset = ggml_webgpu_tensor_offset(dst);
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
// set up parameters
std::vector<uint32_t> params = {
// number of elements-- determines how many threads to dispatch (one for each addition operation)
(uint32_t) ggml_nelements(dst),
// calculate element strides for each tensor
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// number of elements in each dimension of larger tensors (src0 and dst)
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
// number of elements in each dimension of smaller tensor to be broadcasted (src1)
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) src1->ne[2],
(uint32_t) src1->ne[3],
// offsets in terms of elements instead of bytes
(uint32_t) (src0_misalignment / ggml_type_size(src0->type)),
(uint32_t) (src1_misalignment / ggml_type_size(src1->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = src0_offset,
.size = (ggml_nbytes(src0) + src0_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = src1_offset,
.size = (ggml_nbytes(src1) + src1_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = dst_offset,
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; // max threads in a single workgroup
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; // number of workgroups to dispatch to cover all elements
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline, params, entries, wg_x); // dispatch shader
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x);
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@ -814,8 +758,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
}
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
remaining_size);
} else {
// wait for WriteBuffer to complete
ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
@ -849,11 +793,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy();
}
ggml_webgpu_create_buffer(device,
webgpu_ctx->get_tensor_staging_buf,
final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"get_tensor_staging_buf");
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
}
// Copy the data from the buffer to the staging buffer
@ -907,8 +848,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
buf,
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
(size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
"allocated_buffer");
@ -989,102 +929,58 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32,
"mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16,
"mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32,
"mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32,
"mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32,
"mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32,
"mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32,
"mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32,
"mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32,
"mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32,
"mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32,
"mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32,
"mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32,
"mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32,
"mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32,
"mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32,
"mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32,
"mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32,
"mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32,
"mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32,
"mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32,
"mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device,
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32,
"mul_mat_iq4_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
}
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(
webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
@ -1098,10 +994,10 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline, wgsl_add, "add", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@ -1158,9 +1054,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RESHAPE:
return true;
case GGML_OP_ADD:
return op->type == GGML_TYPE_F32;
return true;
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
@ -1248,14 +1143,14 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {};
auto callback =
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message,
void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
};
void * userdata = &ctx->adapter;
ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
@ -1277,21 +1172,21 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR(
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
ctx->instance.WaitAny(ctx->adapter.RequestDevice(
&dev_desc,
wgpu::CallbackMode::AllowSpontaneous,
&dev_desc, wgpu::CallbackMode::AllowSpontaneous,
[ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
std::string(message).c_str());
return;
}
ctx->device = std::move(device);
@ -1303,14 +1198,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ctx->queue = ctx->device.GetQueue();
// Create buffer pool for shader parameters
ctx->param_buf_pool.init(ctx->device,
WEBGPU_NUM_PARAM_BUFS,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
ctx->set_rows_error_buf_pool.init(ctx->device,
WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
@ -1322,16 +1213,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_host_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"debug_host_buf");
ggml_webgpu_create_buffer(ctx->device,
ctx->debug_dev_buf,
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
"debug_dev_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
#endif
static ggml_backend_webgpu_device_context device_ctx;
@ -1342,12 +1227,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
GGML_LOG_INFO(
"ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
"device_desc: %s\n",
info.vendorID,
std::string(info.vendor).c_str(),
std::string(info.architecture).c_str(),
info.deviceID,
std::string(info.device).c_str(),
std::string(info.description).c_str());
info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
std::string(info.device).c_str(), std::string(info.description).c_str());
// See GGML Backend Device Interface section
static ggml_backend_device device = {

View File

@ -1,46 +1,54 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
}
},
{
"REPLS": {
"TYPE" : "f16",
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
var<storage, read_write> dst: array<{{TYPE}}>;
struct Params {
ne: u32,
stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
stride_dst_0: u32,
stride_dst_1: u32,
stride_dst_2: u32,
stride_dst_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
a_ne3: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
};
@group(0) @binding(3)
@ -53,33 +61,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return;
}
// i = thread id, ranges from 0 --> total ne - 1
// i = thread id, ranges from 0 --> total ne - 1
// represents the position in the flat array a we are adding with array b
var i = gid.x;
var i = gid.x;
// given the index of linear a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3]
// we need this because tensor a and b are different shapes
// we need this because tensor a and b are different shapes
// so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
let src1_idx = b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
@ -91,3 +94,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// gid.x used for flat indexing into dst and a, since variable i was modified during calcs
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
}
#end(SHADER)

View File

@ -46,11 +46,17 @@ def generate_variants(shader_path, output_dir, outfile):
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile)
else:
decls_map = parse_decls(extract_block(text, "DECLS"))
shader_template = extract_block(text, "SHADER")
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
except ValueError:
decls_map = {}
shader_template = extract_block(text, "SHADER")
for variant in variants:
decls = variant["DECLS"]
if "DECLS" in variant:
decls = variant["DECLS"]
else:
decls = []
decls_code = ""
for key in decls:
if key not in decls_map:
@ -60,7 +66,12 @@ def generate_variants(shader_path, output_dir, outfile):
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
elif "TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile)