Format with clang-format

This commit is contained in:
Reese Levine 2025-07-30 15:06:09 -07:00
parent 01c8ced232
commit bfff27f130
1 changed files with 379 additions and 283 deletions

View File

@ -1,21 +1,25 @@
/*
WebGPU backend implementation.
Note: Use ClangFormat to format this file.
*/
#include "ggml-webgpu.h" #include "ggml-webgpu.h"
#include <webgpu/webgpu_cpp.h>
#include "ggml-impl.h"
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-wgsl-shaders.hpp" #include "ggml-wgsl-shaders.hpp"
#include <webgpu/webgpu_cpp.h>
#include <condition_variable> #include <condition_variable>
#include <cstring> #include <cstring>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#ifdef GGML_WEBGPU_DEBUG #ifdef GGML_WEBGPU_DEBUG
#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
#else #else
#define WEBGPU_LOG_DEBUG(msg) ((void) 0) # define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG #endif // GGML_WEBGPU_DEBUG
/* Constants */ /* Constants */
@ -29,20 +33,24 @@
/* End Constants */ /* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
static void* const webgpu_ptr_base = (void*)(uintptr_t)0x1000; // NOLINT static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
// Always returns the base offset of a tensor, regardless of views. // Always returns the base offset of a tensor, regardless of views.
static uint64_t webgpu_tensor_offset(const ggml_tensor* tensor) { static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
if (tensor->view_src) { if (tensor->view_src) {
return (uint8_t*)tensor->view_src->data - (uint8_t*)webgpu_ptr_base; return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
} }
return (uint8_t*)tensor->data - (uint8_t*)webgpu_ptr_base; return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
} }
/* Struct definitions */ /* Struct definitions */
// Forward reference // Forward reference
static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer, size_t size, wgpu::BufferUsage usage, const char* label); static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::Buffer & buffer,
size_t size,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_param_bufs { struct webgpu_param_bufs {
wgpu::Buffer host_buf; wgpu::Buffer host_buf;
@ -60,17 +68,23 @@ struct webgpu_param_buf_pool {
for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) { for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
wgpu::Buffer host_buf; wgpu::Buffer host_buf;
wgpu::Buffer dev_buf; wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, "ggml_webgpu_host_params_buf"); ggml_webgpu_create_buffer(device,
ggml_webgpu_create_buffer(device, dev_buf, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_dev_params_buf"); host_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
"ggml_webgpu_host_params_buf");
ggml_webgpu_create_buffer(device,
dev_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
"ggml_webgpu_dev_params_buf");
free.push_back({ host_buf, dev_buf }); free.push_back({ host_buf, dev_buf });
} }
} }
webgpu_param_bufs alloc_bufs() { webgpu_param_bufs alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { cv.wait(lock, [this] { return !free.empty(); });
return !free.empty();
});
webgpu_param_bufs bufs = free.back(); webgpu_param_bufs bufs = free.back();
free.pop_back(); free.pop_back();
return bufs; return bufs;
@ -84,7 +98,7 @@ struct webgpu_param_buf_pool {
void cleanup() { void cleanup() {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
for (auto& bufs : free) { for (auto & bufs : free) {
bufs.host_buf.Destroy(); bufs.host_buf.Destroy();
bufs.dev_buf.Destroy(); bufs.dev_buf.Destroy();
} }
@ -130,7 +144,7 @@ struct ggml_backend_webgpu_reg_context {
webgpu_context webgpu_ctx; webgpu_context webgpu_ctx;
size_t device_count; size_t device_count;
const char* name; const char * name;
}; };
struct ggml_backend_webgpu_device_context { struct ggml_backend_webgpu_device_context {
@ -152,15 +166,19 @@ struct ggml_backend_webgpu_buffer_context {
wgpu::Buffer buffer; wgpu::Buffer buffer;
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
webgpu_ctx(std::move(ctx)), buffer(std::move(buf)) { webgpu_ctx(std::move(ctx)),
} buffer(std::move(buf)) {}
}; };
/* End struct definitions */ /* End struct definitions */
/* WebGPU object initializations */ /* WebGPU object initializations */
static void ggml_webgpu_create_pipeline(wgpu::Device& device, wgpu::ComputePipeline& pipeline, const char* shader_code, const char* label, const std::vector<wgpu::ConstantEntry>& constants = {}) { static void ggml_webgpu_create_pipeline(wgpu::Device & device,
wgpu::ComputePipeline & pipeline,
const char * shader_code,
const char * label,
const std::vector<wgpu::ConstantEntry> & constants = {}) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
wgpu::ShaderSourceWGSL shader_source; wgpu::ShaderSourceWGSL shader_source;
shader_source.code = shader_code; shader_source.code = shader_code;
@ -180,7 +198,11 @@ static void ggml_webgpu_create_pipeline(wgpu::Device& device, wgpu::ComputePipel
pipeline = device.CreateComputePipeline(&pipeline_desc); pipeline = device.CreateComputePipeline(&pipeline_desc);
} }
static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer, size_t size, wgpu::BufferUsage usage, const char* label) { static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::Buffer & buffer,
size_t size,
wgpu::BufferUsage usage,
const char * label) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
wgpu::BufferDescriptor buffer_desc; wgpu::BufferDescriptor buffer_desc;
@ -196,19 +218,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer
/** WebGPU Actions */ /** WebGPU Actions */
// Waits for the queue to finish processing all submitted commands.
// A non-success completion status is logged; it is not reported to the caller.
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
    wgpu::Future work_done = ctx->queue.OnSubmittedWorkDone(
        wgpu::CallbackMode::AllowSpontaneous,
        [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
            if (status == wgpu::QueueWorkDoneStatus::Success) {
                return;
            }
            GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
        });
    // UINT64_MAX is passed as the timeout argument.
    ctx->instance.WaitAny(work_done, UINT64_MAX);
}
static void ggml_backend_webgpu_submit_queue(webgpu_context& ctx) { static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex); std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex);
ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
@ -226,24 +248,34 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context& ctx) {
}); });
} }
// Requests a mapping of `size` bytes of `buffer` starting at `offset` with the
// given map mode, then waits on the returned future. A failed mapping is logged;
// it is not reported to the caller.
static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
                                           wgpu::Buffer &   buffer,
                                           wgpu::MapMode    mode,
                                           size_t           offset,
                                           size_t           size) {
    wgpu::Future map_done = buffer.MapAsync(
        mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
        [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
            if (status == wgpu::MapAsyncStatus::Success) {
                return;
            }
            GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data);
        });
    // UINT64_MAX is passed as the timeout argument.
    ctx->instance.WaitAny(map_done, UINT64_MAX);
}
static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::ComputePipeline& pipeline, std::vector<uint32_t> params, std::vector<wgpu::BindGroupEntry> bind_group_entries, uint32_t wg_x, bool submit_imm = false) { static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
wgpu::ComputePipeline & pipeline,
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_imm = false) {
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
uint32_t* _params = (uint32_t*)params_bufs.host_buf.GetMappedRange();
for (size_t i = 0; i < params.size(); i++) { for (size_t i = 0; i < params.size(); i++) {
_params[i] = params[i]; _params[i] = params[i];
}; };
@ -251,12 +283,10 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
params_bufs.host_buf.Unmap(); params_bufs.host_buf.Unmap();
uint32_t params_bufs_binding_num = bind_group_entries.size(); uint32_t params_bufs_binding_num = bind_group_entries.size();
bind_group_entries.push_back({ bind_group_entries.push_back({ .binding = params_bufs_binding_num,
.binding = params_bufs_binding_num,
.buffer = params_bufs.dev_buf, .buffer = params_bufs.dev_buf,
.offset = 0, .offset = 0,
.size = params_bufs.dev_buf.GetSize() .size = params_bufs.dev_buf.GetSize() });
});
wgpu::BindGroupDescriptor bind_group_desc; wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipeline.GetBindGroupLayout(0); bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
@ -265,11 +295,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
params_bufs.host_buf, 0,
params_bufs.dev_buf, 0,
params_bufs.dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline); pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bind_group); pass.SetBindGroup(0, bind_group);
@ -279,13 +305,13 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
if (submit_imm) { if (submit_imm) {
// Submit immediately // Submit immediately
ctx->queue.Submit(1, &commands); ctx->queue.Submit(1, &commands);
ctx->queue.OnSubmittedWorkDone( ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
wgpu::CallbackMode::AllowSpontaneous,
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) { if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
message.data);
} }
ctx->param_buf_pool.free_bufs({params_bufs}); ctx->param_buf_pool.free_bufs({ params_bufs });
}); });
} else { } else {
// Enqueue commands and only submit if we have enough staged commands // Enqueue commands and only submit if we have enough staged commands
@ -298,20 +324,26 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
} }
} }
// Dispatches the memset pipeline over `buf` with the 32-bit pattern `value`,
// a byte `offset` and a byte `size`, and asks for immediate submission.
static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
                                              wgpu::Buffer &   buf,
                                              uint32_t         value,
                                              size_t           offset,
                                              size_t           size) {
    // Shader parameters, in order: byte offset, byte count, fill pattern.
    std::vector<uint32_t> params;
    params.push_back((uint32_t) offset);
    params.push_back((uint32_t) size);
    params.push_back(value);

    // The whole buffer is bound; the range to fill is carried by the parameters above.
    wgpu::BindGroupEntry whole_buf = { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() };
    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back(whole_buf);

    // Ceiling division of the byte count (padded by 3) by the bytes one workgroup covers.
    // NOTE(review): the +3 presumably rounds up to a whole 32-bit word; confirm against the shader.
    const size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
    const size_t padded_size  = size + 3;
    uint32_t     wg_x         = (padded_size + bytes_per_wg - 1) / bytes_per_wg;

    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
}
static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor* tensor) { static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) {
return webgpu_tensor_offset(tensor) + tensor->view_offs; return webgpu_tensor_offset(tensor) + tensor->view_offs;
} }
// Returns the wgpu::Buffer stored in the context of the backend buffer that `tensor` belongs to.
static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) {
    auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
    return buf_ctx->buffer;
}
@ -319,20 +351,20 @@ static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor* tensor) {
/** GGML Backend Interface */ /** GGML Backend Interface */
// Returns the name stored in the backend's context as a C string.
static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
    auto * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
    return backend_ctx->name.c_str();
}
// Backend teardown hook. Currently only emits a debug log line; nothing is released yet.
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
    auto * ctx = (ggml_backend_webgpu_context *) backend->context;
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");

    // TODO: cleanup
    // Keeps `ctx` referenced when the debug log macro expands to nothing.
    GGML_UNUSED(ctx);
}
static void ggml_webgpu_cpy(webgpu_context& ctx, ggml_tensor* src, ggml_tensor* dst) { static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
size_t src_offset = ggml_backend_webgpu_tensor_offset(src); size_t src_offset = ggml_backend_webgpu_tensor_offset(src);
// assumes power of 2 offset alignment // assumes power of 2 offset alignment
size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
@ -341,21 +373,36 @@ static void ggml_webgpu_cpy(webgpu_context& ctx, ggml_tensor* src, ggml_tensor*
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
uint32_t ne = (uint32_t)ggml_nelements(dst); uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { std::vector<uint32_t> params = { ne,
ne, (uint32_t)(src_misalignment / ggml_type_size(src->type)), (uint32_t)(dst_misalignment / ggml_type_size(dst->type)), (uint32_t) (src_misalignment / ggml_type_size(src->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides // Convert byte-strides to element-strides
(uint32_t)(src->nb[0] / ggml_type_size(src->type)), (uint32_t)(src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
(uint32_t)(src->nb[2] / ggml_type_size(src->type)), (uint32_t)(src->nb[3] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t)(dst->nb[0] / ggml_type_size(dst->type)), (uint32_t)(dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t)(dst->nb[2] / ggml_type_size(dst->type)), (uint32_t)(dst->nb[3] / ggml_type_size(dst->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted // Logical shape — same for both tensors even if permuted
(uint32_t)src->ne[0], (uint32_t)src->ne[1], (uint32_t)src->ne[2], (uint32_t)src->ne[3] (uint32_t) src->ne[0],
}; (uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3] };
std::vector<wgpu::BindGroupEntry> entries = { std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0, .buffer = ggml_backend_webgpu_tensor_buf(src), .offset = src_offset, .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, { .binding = 0,
{ .binding = 1, .buffer = ggml_backend_webgpu_tensor_buf(dst), .offset = dst_offset, .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } .buffer = ggml_backend_webgpu_tensor_buf(src),
.offset = src_offset,
.size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = dst_offset,
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
}; };
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
@ -363,42 +410,52 @@ static void ggml_webgpu_cpy(webgpu_context& ctx, ggml_tensor* src, ggml_tensor*
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
} }
// Builds the parameter block and buffer bindings for a matrix multiplication of
// src0 and src1 into dst, then enqueues it on the mul_mat pipeline.
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
    // Element sizes used to convert the byte strides in nb[] to element strides.
    const size_t src0_elem_size = ggml_type_size(src0->type);
    const size_t src1_elem_size = ggml_type_size(src1->type);

    std::vector<uint32_t> params = {
        (uint32_t) dst->ne[1],                         // number of rows in result (M)
        (uint32_t) dst->ne[0],                         // number of columns in result (N)
        (uint32_t) src0->ne[0],                        // number of columns in src0/src1 (K)
        (uint32_t) (src0->nb[1] / src0_elem_size),     // stride (elements) of src0 in dimension 1
        (uint32_t) (src1->nb[1] / src1_elem_size),     // stride (elements) of src1 in dimension 1
        (uint32_t) (src0->nb[2] / src0_elem_size),     // stride (elements) of src0 in dimension 2
        (uint32_t) (src1->nb[2] / src1_elem_size),     // stride (elements) of src1 in dimension 2
        (uint32_t) (src0->nb[3] / src0_elem_size),     // stride (elements) of src0 in dimension 3
        (uint32_t) (src1->nb[3] / src1_elem_size),     // stride (elements) of src1 in dimension 3
        (uint32_t) src0->ne[2],                        // batch size in dimension 2
        (uint32_t) src0->ne[3],                        // batch size in dimension 3
        (uint32_t) (src1->ne[2] / src0->ne[2]),        // broadcast in dimension 2
        (uint32_t) (src1->ne[3] / src0->ne[3])         // broadcast in dimension 3
    };

    // Each tensor is bound at its own offset with exactly ggml_nbytes() bytes.
    const auto bind_tensor = [](uint32_t slot, ggml_tensor * t) {
        wgpu::BindGroupEntry entry = { .binding = slot,
                                       .buffer  = ggml_backend_webgpu_tensor_buf(t),
                                       .offset  = ggml_backend_webgpu_tensor_offset(t),
                                       .size    = ggml_nbytes(t) };
        return entry;
    };

    std::vector<wgpu::BindGroupEntry> entries;
    entries.push_back(bind_tensor(0, src0));
    entries.push_back(bind_tensor(1, src1));
    entries.push_back(bind_tensor(2, dst));

    // Ceiling division of the number of output elements by WEBGPU_MUL_MAT_WG_SIZE.
    uint32_t wg_x =
        (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;

    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
}
// Returns true if node has enqueued work into the queue, false otherwise // Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor* node) { static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) { if (ggml_is_empty(node)) {
return false; return false;
} }
WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
ggml_tensor* src0 = node->src[0]; ggml_tensor * src0 = node->src[0];
ggml_tensor* src1 = node->src[1]; ggml_tensor * src1 = node->src[1];
switch (node->op) { switch (node->op) {
// no-ops // no-ops
@ -406,11 +463,13 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor* node) {
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
return false; return false;
case GGML_OP_CPY: { case GGML_OP_CPY:
{
ggml_webgpu_cpy(ctx, src0, node); ggml_webgpu_cpy(ctx, src0, node);
break; break;
} }
case GGML_OP_MUL_MAT: { case GGML_OP_MUL_MAT:
{
ggml_webgpu_mul_mat(ctx, src0, src1, node); ggml_webgpu_mul_mat(ctx, src0, src1, node);
break; break;
} }
@ -420,10 +479,10 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor* node) {
return true; return true;
} }
static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph* cgraph) { static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_webgpu_context* backend_ctx = static_cast<ggml_backend_webgpu_context*>(backend->context); ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
webgpu_context ctx = backend_ctx->webgpu_ctx; webgpu_context ctx = backend_ctx->webgpu_ctx;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
@ -458,34 +517,44 @@ static ggml_backend_i ggml_backend_webgpu_i = {
// Destroys the wgpu::Buffer held by this backend buffer's context.
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()");

    auto * buf_ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
    buf_ctx->buffer.Destroy();
}
// Returns the "fake" base pointer; the result does not depend on `buffer`.
static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
    GGML_UNUSED(buffer);

    void * const fake_base = webgpu_ptr_base;
    return fake_base;
}
// Sets `size` bytes of `tensor`, starting `offset` bytes into the tensor's data,
// to the byte `value`. A zero-sized request returns without dispatching anything.
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
                                                     ggml_tensor *         tensor,
                                                     uint8_t               value,
                                                     size_t                offset,
                                                     size_t                size) {
    if (size == 0) {
        WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
        return;
    }

    // Cast `value` so the stream prints it as a number rather than as a raw character,
    // matching how ggml_backend_webgpu_buffer_clear logs its value.
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", "
                                                                 << (uint32_t) value << ", " << offset << ", " << size
                                                                 << ")");

    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;

    // Tensor offset (base offset plus view offset) plus the caller's offset into the tensor.
    size_t total_offset = ggml_backend_webgpu_tensor_offset(tensor) + offset;

    // This is a trick to set all bytes of a u32 to the same 1 byte value.
    uint32_t val32 = (uint32_t) value * 0x01010101;
    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
}
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor* tensor, const void* data, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_tensor * tensor,
ggml_backend_webgpu_buffer_context* buf_ctx = (ggml_backend_webgpu_buffer_context*)buffer->context; const void * data,
size_t offset,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
<< offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
@ -498,17 +567,23 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
// pack the remaining bytes into a uint32_t // pack the remaining bytes into a uint32_t
uint32_t val32 = 0; uint32_t val32 = 0;
for (size_t i = 0; i < remaining_size; i++) { for (size_t i = 0; i < remaining_size; i++) {
((uint8_t*)&val32)[i] = ((const uint8_t*)data)[size - remaining_size + i]; ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
} }
// memset the remaining bytes // memset the remaining bytes
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
} }
} }
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); const ggml_tensor * tensor,
void * data,
size_t offset,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
<< offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context* buf_ctx = (ggml_backend_webgpu_buffer_context*)buffer->context; ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
wgpu::Device device = webgpu_ctx->device; wgpu::Device device = webgpu_ctx->device;
@ -522,14 +597,16 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex); std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
if (webgpu_ctx->get_tensor_staging_buf == nullptr || if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
// Create a new staging buffer if it doesn't exist or is too small // Create a new staging buffer if it doesn't exist or is too small
if (webgpu_ctx->get_tensor_staging_buf) { if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy(); webgpu_ctx->get_tensor_staging_buf.Destroy();
} }
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, ggml_webgpu_create_buffer(device,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); webgpu_ctx->get_tensor_staging_buf,
final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"get_tensor_staging_buf");
} }
// Copy the data from the buffer to the staging buffer // Copy the data from the buffer to the staging buffer
@ -542,7 +619,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
// Map the staging buffer to read the data // Map the staging buffer to read the data
ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
// Must specify size here since the staging buffer might be larger than the tensor size // Must specify size here since the staging buffer might be larger than the tensor size
const void* mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
// Copy the data from the mapped range to the output buffer // Copy the data from the mapped range to the output buffer
std::memcpy(data, mapped_range, size); std::memcpy(data, mapped_range, size);
@ -550,9 +627,9 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
} }
// Buffer-interface "clear" hook: sets every byte of the backing WebGPU buffer to `value`.
//   buffer - ggml buffer whose context holds the WebGPU context and the wgpu buffer
//   value  - byte value forwarded to the memset helper
static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << static_cast<uint32_t>(value) << ")");

    auto * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);

    // The cleared range starts at offset 0 and spans the full buffer->size.
    ggml_backend_webgpu_buffer_memset(ctx->webgpu_ctx, ctx->buffer, value, 0, buffer->size);
}
@ -572,32 +649,36 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
/* GGML Backend Buffer Type Interface */ /* GGML Backend Buffer Type Interface */
// Returns the name of the device that owns this buffer type.
// The returned pointer refers to the device context's `device_name` string.
static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(buft->device->context);
    return dev_ctx->device_name.c_str();
}
// Allocates a WebGPU buffer of `size` bytes and wraps it in a ggml backend buffer.
//   buft - buffer type; its device context supplies the shared WebGPU context
//   size - requested size in bytes, passed through to both the wgpu buffer and the ggml buffer
// Returns the buffer produced by ggml_backend_buffer_init.
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                                                          size_t                     size) {
    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");

    auto * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);

    // Created as a storage buffer that can also be the source or destination of copies.
    const wgpu::BufferUsage usage =
        wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst;

    wgpu::Buffer storage;
    ggml_webgpu_create_buffer(dev_ctx->webgpu_ctx->device, storage, size, usage, "allocated_buffer");

    // The heap-allocated context is handed to the ggml buffer as its `context` pointer.
    ggml_backend_webgpu_buffer_context * wrapped = new ggml_backend_webgpu_buffer_context(dev_ctx->webgpu_ctx, storage);

    return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, wrapped, size);
}
// Reports the buffer-type alignment as the device's minStorageBufferOffsetAlignment limit.
static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(buft->device->context);
    const auto & limits  = dev_ctx->webgpu_ctx->limits;
    return limits.minStorageBufferOffsetAlignment;
}
// Reports the largest allocation this buffer type supports.
// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding,
// so the binding-size limit is the value returned here.
static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(buft->device->context);
    const auto & limits  = dev_ctx->webgpu_ctx->limits;
    return limits.maxStorageBufferBindingSize;
}
@ -605,18 +686,18 @@ static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_t
/* GGML Backend Device Interface */ /* GGML Backend Device Interface */
// Returns the device name held in the device context (`device_name`).
// The pointer refers to the context's string.
static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(dev->context);
    return dev_ctx->device_name.c_str();
}
// Returns the device description held in the device context (`device_desc`).
// The pointer refers to the context's string.
static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_webgpu_device_context *>(dev->context);
    return dev_ctx->device_desc.c_str();
}
static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t* free, size_t* total) { static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_webgpu_device_context* ctx = static_cast<ggml_backend_webgpu_device_context*>(dev->context); ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
// TODO: what do we actually want to return here? maxBufferSize might not be the full available memory. // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
*free = ctx->webgpu_ctx->limits.maxBufferSize; *free = ctx->webgpu_ctx->limits.maxBufferSize;
*total = ctx->webgpu_ctx->limits.maxBufferSize; *total = ctx->webgpu_ctx->limits.maxBufferSize;
@ -627,7 +708,7 @@ static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backe
return GGML_BACKEND_DEVICE_TYPE_GPU; return GGML_BACKEND_DEVICE_TYPE_GPU;
} }
static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props* props) { static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_webgpu_device_get_name(dev); props->name = ggml_backend_webgpu_device_get_name(dev);
props->description = ggml_backend_webgpu_device_get_description(dev); props->description = ggml_backend_webgpu_device_get_description(dev);
props->type = ggml_backend_webgpu_device_get_type(dev); props->type = ggml_backend_webgpu_device_get_type(dev);
@ -641,16 +722,17 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct
} }
// Returns the GUID identifying the WebGPU backend.
// The GUID is the address of a static 16-character string literal reinterpreted as a ggml_guid_t
// (presumably matching the size of ggml's GUID type -- confirm against its declaration).
static ggml_guid_t ggml_backend_webgpu_guid(void) {
    static const char * guid_bytes = "__ggml_webgpu :)";
    // Constness is dropped only to satisfy the ggml_guid_t pointer type.
    return reinterpret_cast<ggml_guid_t>(const_cast<char *>(guid_bytes));
}
static void ggml_webgpu_init_memset_pipeline(webgpu_context& webgpu_ctx) { static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
// we use the maximum workgroup size for the memset pipeline // we use the maximum workgroup size for the memset pipeline
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled // Size the bytes_per_thread so that the largest buffer size can be handled
webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; webgpu_ctx->memset_bytes_per_thread =
(webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
std::vector<wgpu::ConstantEntry> constants(2); std::vector<wgpu::ConstantEntry> constants(2);
constants[0].key = "wg_size"; constants[0].key = "wg_size";
constants[0].value = max_wg_size; constants[0].value = max_wg_size;
@ -659,23 +741,23 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context& webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants);
} }
// Creates the "mul_mat" compute pipeline from the wgsl_mul_mat shader source and stores it in
// webgpu_ctx->mul_mat_pipeline. No constant entries are passed for this pipeline.
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
    auto & device   = webgpu_ctx->device;
    auto & pipeline = webgpu_ctx->mul_mat_pipeline;
    ggml_webgpu_create_pipeline(device, pipeline, wgsl_mul_mat, "mul_mat");
}
// Creates the "cpy" compute pipeline from the wgsl_cpy shader source and stores it in
// webgpu_ctx->cpy_pipeline. A single constant entry, "wg_size", is set to the device's
// maxComputeWorkgroupSizeX limit.
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
    std::vector<wgpu::ConstantEntry> constants(1);

    wgpu::ConstantEntry & wg_size = constants.front();
    wg_size.key                   = "wg_size";
    wg_size.value                 = webgpu_ctx->limits.maxComputeWorkgroupSizeX;

    auto & device   = webgpu_ctx->device;
    auto & pipeline = webgpu_ctx->cpy_pipeline;
    ggml_webgpu_create_pipeline(device, pipeline, wgsl_cpy, "cpy", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char* params) { static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params); GGML_UNUSED(params);
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
ggml_backend_webgpu_device_context* dev_ctx = static_cast<ggml_backend_webgpu_device_context*>(dev->context); ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
// Multiple threads may try to initialize the device // Multiple threads may try to initialize the device
@ -686,17 +768,23 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
dev_desc.requiredLimits = &webgpu_ctx->limits; dev_desc.requiredLimits = &webgpu_ctx->limits;
dev_desc.requiredFeatures = webgpu_ctx->features.features; dev_desc.requiredFeatures = webgpu_ctx->features.features;
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount; dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, dev_desc.SetDeviceLostCallback(
[](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) { wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device); GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data); GGML_LOG_ERROR(
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
}); });
dev_desc.SetUncapturedErrorCallback( dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) { [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device); GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data); GGML_LOG_ERROR(
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
}); });
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::AllowSpontaneous, webgpu_ctx->instance.WaitAny(
webgpu_ctx->adapter.RequestDevice(
&dev_desc,
wgpu::CallbackMode::AllowSpontaneous,
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) { if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
@ -704,8 +792,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
} }
webgpu_ctx->device = std::move(device); webgpu_ctx->device = std::move(device);
}), }),
UINT64_MAX UINT64_MAX);
);
GGML_ASSERT(webgpu_ctx->device != nullptr); GGML_ASSERT(webgpu_ctx->device != nullptr);
// Initialize (compute) queue // Initialize (compute) queue
@ -746,7 +833,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL, // defaults to false /* .is_host = */ NULL, // defaults to false
}, },
/* .device = */ dev, /* .device = */
dev,
/* .context = */ NULL, /* .context = */ NULL,
}; };
@ -758,7 +846,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
} }
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) { static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
GGML_UNUSED(dev); GGML_UNUSED(dev);
switch (op->op) { switch (op->op) {
@ -797,13 +885,13 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
/* GGML Backend Registration Interface */ /* GGML Backend Registration Interface */
// Returns the registration name stored in the registry context (`name`).
static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_webgpu_reg_context *>(reg->context);
    return reg_ctx->name;
}
// Returns the number of devices recorded in the registry context (`device_count`).
static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_webgpu_reg_context *>(reg->context);
    return reg_ctx->device_count;
}
@ -813,20 +901,22 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
GGML_ASSERT(index == 0); GGML_ASSERT(index == 0);
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
ggml_backend_webgpu_reg_context* reg_ctx = static_cast<ggml_backend_webgpu_reg_context*>(reg->context); ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
webgpu_context ctx = reg_ctx->webgpu_ctx; webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {}; wgpu::RequestAdapterOptions options = {};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char* message, void* userdata) { auto callback =
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) { if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return; return;
} }
*static_cast<wgpu::Adapter*>(userdata) = std::move(adapter); *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
}; };
void* userdata = &ctx->adapter; void * userdata = &ctx->adapter;
ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr); GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits); ctx->adapter.GetLimits(&ctx->limits);
@ -840,8 +930,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
device_ctx.device_name = GGML_WEBGPU_NAME; device_ctx.device_name = GGML_WEBGPU_NAME;
device_ctx.device_desc = std::string(info.description.data); device_ctx.device_desc = std::string(info.description.data);
GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n", GGML_LOG_INFO(
info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data); "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
"device_desc: %s\n",
info.vendorID,
info.vendor.data,
info.architecture.data,
info.deviceID,
info.device.data,
info.description.data);
// See GGML Backend Device Interface section // See GGML Backend Device Interface section
static ggml_backend_device device = { static ggml_backend_device device = {
@ -852,7 +949,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
return &device; return &device;
} }
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
/* .get_name = */ ggml_backend_webgpu_reg_get_name, /* .get_name = */ ggml_backend_webgpu_reg_get_name,
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count, /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,