Add templated addition, clean up code

Reese Levine 2025-09-04 14:12:44 -07:00
parent 1b16a91183
commit 7f9ee10e75
3 changed files with 179 additions and 282 deletions


@@ -125,7 +125,7 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline mul_mat_pipeline[30][2];
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;
-    wgpu::ComputePipeline add_pipeline;
+    wgpu::ComputePipeline add_pipeline[2];

     size_t memset_bytes_per_thread;
@ -233,11 +233,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex); std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
if (ctx->callback_futures.empty()) { if (ctx->callback_futures.empty()) {
// no existing callbacks, wait on queue submission // no existing callbacks, wait on queue submission
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( ctx->instance.WaitAny(
wgpu::CallbackMode::AllowSpontaneous, ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) { if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
} }
}), }),
UINT64_MAX); UINT64_MAX);
@@ -287,10 +288,7 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
     // Check for errrors in SET_ROWS operations
     for (auto & error_bufs : staged_set_row_error_bufs) {
         wgpu::Future f = error_bufs.host_buf.MapAsync(
-            wgpu::MapMode::Read,
-            0,
-            error_bufs.host_buf.GetSize(),
-            wgpu::CallbackMode::AllowSpontaneous,
+            wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
             [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                 if (status != wgpu::MapAsyncStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
@@ -312,10 +310,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                                            wgpu::MapMode mode,
                                            size_t offset,
                                            size_t size) {
-    ctx->instance.WaitAny(buffer.MapAsync(mode,
-                                          offset,
-                                          size,
-                                          wgpu::CallbackMode::AllowSpontaneous,
+    ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
                                           [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                                               if (status != wgpu::MapAsyncStatus::Success) {
                                                   GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
@@ -465,23 +460,17 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
 static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);

-    std::vector<uint32_t> params = { ne,
-                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
-                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-                                     // Convert byte-strides to element-strides
-                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
-                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
-                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
-                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
-                                     (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
-                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
-                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
-                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-                                     // Logical shape — same for both tensors even if permuted
-                                     (uint32_t) src->ne[0],
-                                     (uint32_t) src->ne[1],
-                                     (uint32_t) src->ne[2],
-                                     (uint32_t) src->ne[3] };
+    std::vector<uint32_t> params = {
+        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        // Convert byte-strides to element-strides
+        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        // Logical shape — same for both tensors even if permuted
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+    };

     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
@@ -510,27 +499,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
         error_bufs.host_buf.Unmap();
     }

-    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
-                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
-                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-                                     // Convert byte-strides to element-strides
-                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
-                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
-                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
-                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
-                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
-                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
-                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
-                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
-                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-                                     // Shape of src
-                                     (uint32_t) src->ne[0],
-                                     (uint32_t) src->ne[1],
-                                     (uint32_t) src->ne[2],
-                                     (uint32_t) src->ne[3],
-                                     // Shape of idx
-                                     (uint32_t) (idx->ne[1]),
-                                     (uint32_t) (idx->ne[2]) };
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        // Convert byte-strides to element-strides
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        // Shape of src
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
+        // Shape of idx
+        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
+    };

     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
@@ -598,83 +581,44 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
 }

 static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
-    size_t src0_offset = ggml_webgpu_tensor_offset(src0);
-    // assumes power of 2 offset alignment
-    size_t src0_misalignment = src0_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-    // align to minimum offset alignment
-    src0_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
-    size_t src1_offset = ggml_webgpu_tensor_offset(src1);
-    size_t src1_misalignment = src1_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-    src1_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
-    size_t dst_offset = ggml_webgpu_tensor_offset(dst);
-    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
-
-    // set up parameters
     std::vector<uint32_t> params = {
-        // number of elements-- determines how many threads to dispatch (one for each addition operation)
         (uint32_t) ggml_nelements(dst),
-        // calculate element strides for each tensor
-        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
-        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
-        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
-        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
-        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // number of elements in each dimension of larger tensors (src0 and dst)
-        (uint32_t) dst->ne[0],
-        (uint32_t) dst->ne[1],
-        (uint32_t) dst->ne[2],
-        (uint32_t) dst->ne[3],
-        // number of elements in each dimension of smaller tensor to be broadcasted (src1)
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
        (uint32_t) src1->ne[0],
         (uint32_t) src1->ne[1],
         (uint32_t) src1->ne[2],
         (uint32_t) src1->ne[3],
-        // offsets in terms of elements instead of bytes
-        (uint32_t) (src0_misalignment / ggml_type_size(src0->type)),
-        (uint32_t) (src1_misalignment / ggml_type_size(src1->type)),
-        (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
     };

     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
           .buffer = ggml_webgpu_tensor_buf(src0),
-          .offset = src0_offset,
-          .size = (ggml_nbytes(src0) + src0_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
         { .binding = 1,
           .buffer = ggml_webgpu_tensor_buf(src1),
-          .offset = src1_offset,
-          .size = (ggml_nbytes(src1) + src1_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
         { .binding = 2,
           .buffer = ggml_webgpu_tensor_buf(dst),
-          .offset = dst_offset,
-          .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
+          .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };

-    size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;  // max threads in a single workgroup
-    uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;  // number of workgroups to dispatch to cover all elements
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline, params, entries, wg_x);  // dispatch shader
+    size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
+    uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x);
 }

 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
@@ -814,8 +758,8 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
         // memset the remaining bytes
-        ggml_backend_webgpu_buffer_memset(
-            webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
+        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
+                                          remaining_size);
     } else {
         // wait for WriteBuffer to complete
         ggml_backend_webgpu_wait_on_submission(webgpu_ctx);
@@ -849,11 +793,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
         if (webgpu_ctx->get_tensor_staging_buf) {
             webgpu_ctx->get_tensor_staging_buf.Destroy();
         }
-        ggml_webgpu_create_buffer(device,
-                                  webgpu_ctx->get_tensor_staging_buf,
-                                  final_size,
-                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
-                                  "get_tensor_staging_buf");
+        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
+                                  wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
     }

     // Copy the data from the buffer to the staging buffer
@@ -907,8 +848,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);

     wgpu::Buffer buf;
-    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
-                              buf,
+    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf,
                               (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1),
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
                               "allocated_buffer");
@@ -989,102 +929,58 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
 }

 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
-                                wgsl_mul_mat_f32_f32,
-                                "mul_mat_f32_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
+                                wgsl_mul_mat_f32_f32, "mul_mat_f32_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
-                                wgsl_mul_mat_f16_f16,
-                                "mul_mat_f16_f16");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
+                                wgsl_mul_mat_f16_f16, "mul_mat_f16_f16");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
-                                wgsl_mul_mat_f16_f32,
-                                "mul_mat_f16_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
+                                wgsl_mul_mat_f16_f32, "mul_mat_f16_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
-                                wgsl_mul_mat_q4_0_f32,
-                                "mul_mat_q4_0_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
+                                wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
-                                wgsl_mul_mat_q4_1_f32,
-                                "mul_mat_q4_1_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
+                                wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
-                                wgsl_mul_mat_q5_0_f32,
-                                "mul_mat_q5_0_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
+                                wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
-                                wgsl_mul_mat_q5_1_f32,
-                                "mul_mat_q5_1_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
+                                wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
-                                wgsl_mul_mat_q8_0_f32,
-                                "mul_mat_q8_0_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
+                                wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
-                                wgsl_mul_mat_q2_k_f32,
-                                "mul_mat_q2_k_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
+                                wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
-                                wgsl_mul_mat_q3_k_f32,
-                                "mul_mat_q3_k_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
+                                wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
-                                wgsl_mul_mat_q4_k_f32,
-                                "mul_mat_q4_k_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
+                                wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
-                                wgsl_mul_mat_q5_k_f32,
-                                "mul_mat_q5_k_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
+                                wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
-                                wgsl_mul_mat_q6_k_f32,
-                                "mul_mat_q6_k_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
+                                wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq2_xxs_f32,
-                                "mul_mat_iq2_xxs_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq2_xs_f32,
-                                "mul_mat_iq2_xs_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq2_s_f32,
-                                "mul_mat_iq2_s_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq3_xxs_f32,
-                                "mul_mat_iq3_xxs_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq3_s_f32,
-                                "mul_mat_iq3_s_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq1_s_f32,
-                                "mul_mat_iq1_s_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq1_m_f32,
-                                "mul_mat_iq1_m_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq4_nl_f32,
-                                "mul_mat_iq4_nl_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
-    ggml_webgpu_create_pipeline(webgpu_ctx->device,
-                                webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
-                                wgsl_mul_mat_iq4_xs_f32,
-                                "mul_mat_iq4_xs_f32");
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
+                                wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
 }

 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
     constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-    ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
+                                constants);
 }

 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
@@ -1098,10 +994,10 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
     constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline, wgsl_add, "add", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants);
 }

 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
@@ -1158,9 +1054,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_RESHAPE:
-            return true;
         case GGML_OP_ADD:
-            return op->type == GGML_TYPE_F32;
+            return true;
         case GGML_OP_CPY:
         case GGML_OP_SET_ROWS:
             return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
@ -1248,8 +1143,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx; webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {}; wgpu::RequestAdapterOptions options = {};
auto callback = auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message,
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) { void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) { if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return; return;
@@ -1277,21 +1172,21 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
         wgpu::CallbackMode::AllowSpontaneous,
         [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
             GGML_UNUSED(device);
-            GGML_LOG_ERROR(
-                "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
+            GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+                           std::string(message).c_str());
         });
     dev_desc.SetUncapturedErrorCallback(
         [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
             GGML_UNUSED(device);
-            GGML_LOG_ERROR(
-                "ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), std::string(message).c_str());
+            GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
+                           std::string(message).c_str());
         });
     ctx->instance.WaitAny(ctx->adapter.RequestDevice(
-                              &dev_desc,
-                              wgpu::CallbackMode::AllowSpontaneous,
+                              &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
                               [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
                                   if (status != wgpu::RequestDeviceStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
+                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
+                                                     std::string(message).c_str());
                                       return;
                                   }
                                   ctx->device = std::move(device);
@@ -1303,14 +1198,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ctx->queue = ctx->device.GetQueue();

     // Create buffer pool for shader parameters
-    ctx->param_buf_pool.init(ctx->device,
-                             WEBGPU_NUM_PARAM_BUFS,
-                             WEBGPU_PARAMS_BUF_SIZE_BYTES,
+    ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
                              wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
-    ctx->set_rows_error_buf_pool.init(ctx->device,
-                                      WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
-                                      WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+    ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
                                       wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
                                       wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
@@ -1322,16 +1213,10 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
-    ggml_webgpu_create_buffer(ctx->device,
-                              ctx->debug_host_buf,
-                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
-                              "debug_host_buf");
-    ggml_webgpu_create_buffer(ctx->device,
-                              ctx->debug_dev_buf,
-                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc,
-                              "debug_dev_buf");
+    ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
+    ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
 #endif

     static ggml_backend_webgpu_device_context device_ctx;
@@ -1342,12 +1227,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     GGML_LOG_INFO(
         "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
         "device_desc: %s\n",
-        info.vendorID,
-        std::string(info.vendor).c_str(),
-        std::string(info.architecture).c_str(),
-        info.deviceID,
-        std::string(info.device).c_str(),
-        std::string(info.description).c_str());
+        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
+        std::string(info.device).c_str(), std::string(info.description).c_str());

     // See GGML Backend Device Interface section
     static ggml_backend_device device = {


@@ -1,46 +1,54 @@
+#define(VARIANTS)
+[
+  {
+    "REPLS": {
+      "TYPE" : "f32",
+    }
+  },
+  {
+    "REPLS": {
+      "TYPE" : "f16",
+    }
+  }
+]
+#end(VARIANTS)
+
+#define(SHADER)
 enable f16;

 @group(0) @binding(0)
-var<storage, read_write> src0: array<f32>;
+var<storage, read_write> src0: array<{{TYPE}}>;

 @group(0) @binding(1)
-var<storage, read_write> src1: array<f32>;
+var<storage, read_write> src1: array<{{TYPE}}>;

 @group(0) @binding(2)
-var<storage, read_write> dst: array<f32>;
+var<storage, read_write> dst: array<{{TYPE}}>;

 struct Params {
     ne: u32,
-    stride_src0_0: u32,
-    stride_src0_1: u32,
-    stride_src0_2: u32,
-    stride_src0_3: u32,
+    // offsets in elements
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
     stride_src1_0: u32,
     stride_src1_1: u32,
     stride_src1_2: u32,
     stride_src1_3: u32,
-    stride_dst_0: u32,
-    stride_dst_1: u32,
-    stride_dst_2: u32,
-    stride_dst_3: u32,
     a_ne0: u32,
     a_ne1: u32,
     a_ne2: u32,
-    a_ne3: u32,
     b_ne0: u32,
     b_ne1: u32,
     b_ne2: u32,
     b_ne3: u32,
-    // offsets in elements
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
 };

 @group(0) @binding(3)
@@ -63,15 +71,11 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
     i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
     let a_i2 = i / (params.a_ne1 * params.a_ne0);
     i = i % (params.a_ne1 * params.a_ne0);
     let a_i1 = i / params.a_ne0;
     let a_i0 = i % params.a_ne0;

     // handle repetition of b
     // index loops back to the beginning and repeats after elements are exhausted = modulo
     let b_i0 = a_i0 % params.b_ne0;
@@ -79,7 +83,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     let b_i2 = a_i2 % params.b_ne2;
     let b_i3 = a_i3 % params.b_ne3;

     // compute index for position in b's flat array
     let src1_idx = b_i0 * params.stride_src1_0 +
                    b_i1 * params.stride_src1_1 +
@@ -91,3 +94,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     // gid.x used for flat indexing into dst and a, since variable i was modified during calcs
     dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
 }
+
+#end(SHADER)
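
A minimal sketch of how a variant's REPLS block is applied to the template above, assuming a plain string substitution for the {{TYPE}} placeholder (the apply_repls helper here is illustrative, not the script's actual replace_placeholders):

    def apply_repls(template: str, repls: dict) -> str:
        # Substitute each {{KEY}} placeholder with the value from a variant's REPLS map.
        for key, value in repls.items():
            template = template.replace("{{" + key + "}}", value)
        return template

    line = "var<storage, read_write> src0: array<{{TYPE}}>;"
    print(apply_repls(line, {"TYPE": "f16"}))  # var<storage, read_write> src0: array<f16>;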


@@ -46,11 +46,17 @@ def generate_variants(shader_path, output_dir, outfile):
     except ValueError:
         write_shader(shader_base_name, text, output_dir, outfile)
     else:
-        decls_map = parse_decls(extract_block(text, "DECLS"))
-        shader_template = extract_block(text, "SHADER")
+        try:
+            decls_map = parse_decls(extract_block(text, "DECLS"))
+        except ValueError:
+            decls_map = {}
+        shader_template = extract_block(text, "SHADER")
+
         for variant in variants:
-            decls = variant["DECLS"]
+            if "DECLS" in variant:
+                decls = variant["DECLS"]
+            else:
+                decls = []
             decls_code = ""
             for key in decls:
                 if key not in decls_map:
@@ -60,7 +66,12 @@ def generate_variants(shader_path, output_dir, outfile):
             shader_variant = replace_placeholders(shader_template, variant["REPLS"])
             final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)

-            output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
+            if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
+                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
+            elif "TYPE" in variant["REPLS"]:
+                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
+            else:
+                output_name = shader_base_name
             write_shader(output_name, final_shader, output_dir, outfile)
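
A minimal standalone sketch of the naming rule introduced above, with illustrative variant dicts (add and mul_mat used as example base names, not read from a real shader file):

    def variant_output_name(shader_base_name: str, repls: dict) -> str:
        # Mirrors the branch above: two-type variants, single-type variants, or no suffix at all.
        if "SRC0_TYPE" in repls and "SRC1_TYPE" in repls:
            return f"{shader_base_name}_" + "_".join([repls["SRC0_TYPE"], repls["SRC1_TYPE"]])
        elif "TYPE" in repls:
            return f"{shader_base_name}_" + repls["TYPE"]
        else:
            return shader_base_name

    print(variant_output_name("add", {"TYPE": "f32"}))  # add_f32
    print(variant_output_name("add", {"TYPE": "f16"}))  # add_f16
    print(variant_output_name("mul_mat", {"SRC0_TYPE": "q4_0", "SRC1_TYPE": "f32"}))  # mul_mat_q4_0_f32

Names produced this way line up with the wgsl_add_f32 and wgsl_add_f16 symbols that the backend now passes to ggml_webgpu_create_pipeline for the two entries of add_pipeline.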