Format with clang-format

This commit is contained in:
Reese Levine 2025-07-30 15:06:09 -07:00
parent 01c8ced232
commit bfff27f130
1 changed files with 379 additions and 283 deletions

View File

@ -1,12 +1,16 @@
/*
WebGPU backend implementation.
Note: Use ClangFormat to format this file.
*/
#include "ggml-webgpu.h" #include "ggml-webgpu.h"
#include <webgpu/webgpu_cpp.h>
#include "ggml-impl.h"
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-wgsl-shaders.hpp" #include "ggml-wgsl-shaders.hpp"
#include <webgpu/webgpu_cpp.h>
#include <condition_variable> #include <condition_variable>
#include <cstring> #include <cstring>
#include <mutex> #include <mutex>
@ -42,7 +46,11 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor* tensor) {
/* Struct definitions */ /* Struct definitions */
// Forward reference // Forward reference
static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer, size_t size, wgpu::BufferUsage usage, const char* label); static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::Buffer & buffer,
size_t size,
wgpu::BufferUsage usage,
const char * label);
struct webgpu_param_bufs { struct webgpu_param_bufs {
wgpu::Buffer host_buf; wgpu::Buffer host_buf;
@ -60,17 +68,23 @@ struct webgpu_param_buf_pool {
for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) { for (int i = 0; i < WEBGPU_NUM_PARAM_BUFS; i++) {
wgpu::Buffer host_buf; wgpu::Buffer host_buf;
wgpu::Buffer dev_buf; wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, "ggml_webgpu_host_params_buf"); ggml_webgpu_create_buffer(device,
ggml_webgpu_create_buffer(device, dev_buf, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_dev_params_buf"); host_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite,
"ggml_webgpu_host_params_buf");
ggml_webgpu_create_buffer(device,
dev_buf,
WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
"ggml_webgpu_dev_params_buf");
free.push_back({ host_buf, dev_buf }); free.push_back({ host_buf, dev_buf });
} }
} }
webgpu_param_bufs alloc_bufs() { webgpu_param_bufs alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [this] { cv.wait(lock, [this] { return !free.empty(); });
return !free.empty();
});
webgpu_param_bufs bufs = free.back(); webgpu_param_bufs bufs = free.back();
free.pop_back(); free.pop_back();
return bufs; return bufs;
@ -152,15 +166,19 @@ struct ggml_backend_webgpu_buffer_context {
wgpu::Buffer buffer; wgpu::Buffer buffer;
ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
webgpu_ctx(std::move(ctx)), buffer(std::move(buf)) { webgpu_ctx(std::move(ctx)),
} buffer(std::move(buf)) {}
}; };
/* End struct definitions */ /* End struct definitions */
/* WebGPU object initializations */ /* WebGPU object initializations */
static void ggml_webgpu_create_pipeline(wgpu::Device& device, wgpu::ComputePipeline& pipeline, const char* shader_code, const char* label, const std::vector<wgpu::ConstantEntry>& constants = {}) { static void ggml_webgpu_create_pipeline(wgpu::Device & device,
wgpu::ComputePipeline & pipeline,
const char * shader_code,
const char * label,
const std::vector<wgpu::ConstantEntry> & constants = {}) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
wgpu::ShaderSourceWGSL shader_source; wgpu::ShaderSourceWGSL shader_source;
shader_source.code = shader_code; shader_source.code = shader_code;
@ -180,7 +198,11 @@ static void ggml_webgpu_create_pipeline(wgpu::Device& device, wgpu::ComputePipel
pipeline = device.CreateComputePipeline(&pipeline_desc); pipeline = device.CreateComputePipeline(&pipeline_desc);
} }
static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer, size_t size, wgpu::BufferUsage usage, const char* label) { static void ggml_webgpu_create_buffer(wgpu::Device & device,
wgpu::Buffer & buffer,
size_t size,
wgpu::BufferUsage usage,
const char * label) {
WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()");
wgpu::BufferDescriptor buffer_desc; wgpu::BufferDescriptor buffer_desc;
@ -198,14 +220,14 @@ static void ggml_webgpu_create_buffer(wgpu::Device& device, wgpu::Buffer& buffer
static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
// Wait for the queue to finish processing all commands // Wait for the queue to finish processing all commands
ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) { if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data);
} }
}), }),
UINT64_MAX UINT64_MAX);
);
} }
static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
@ -226,23 +248,33 @@ static void ggml_backend_webgpu_submit_queue(webgpu_context& ctx) {
}); });
} }
static void ggml_backend_webgpu_map_buffer(webgpu_context& ctx, wgpu::Buffer& buffer, wgpu::MapMode mode, size_t offset, size_t size) { static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
ctx->instance.WaitAny(buffer.MapAsync( wgpu::Buffer & buffer,
mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, wgpu::MapMode mode,
size_t offset,
size_t size) {
ctx->instance.WaitAny(buffer.MapAsync(mode,
offset,
size,
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::MapAsyncStatus status, wgpu::StringView message) { [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
if (status != wgpu::MapAsyncStatus::Success) { if (status != wgpu::MapAsyncStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
message.data);
} }
}), }),
UINT64_MAX UINT64_MAX);
);
} }
static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::ComputePipeline& pipeline, std::vector<uint32_t> params, std::vector<wgpu::BindGroupEntry> bind_group_entries, uint32_t wg_x, bool submit_imm = false) { static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx,
wgpu::ComputePipeline & pipeline,
std::vector<uint32_t> params,
std::vector<wgpu::BindGroupEntry> bind_group_entries,
uint32_t wg_x,
bool submit_imm = false) {
webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
for (size_t i = 0; i < params.size(); i++) { for (size_t i = 0; i < params.size(); i++) {
_params[i] = params[i]; _params[i] = params[i];
@ -251,12 +283,10 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
params_bufs.host_buf.Unmap(); params_bufs.host_buf.Unmap();
uint32_t params_bufs_binding_num = bind_group_entries.size(); uint32_t params_bufs_binding_num = bind_group_entries.size();
bind_group_entries.push_back({ bind_group_entries.push_back({ .binding = params_bufs_binding_num,
.binding = params_bufs_binding_num,
.buffer = params_bufs.dev_buf, .buffer = params_bufs.dev_buf,
.offset = 0, .offset = 0,
.size = params_bufs.dev_buf.GetSize() .size = params_bufs.dev_buf.GetSize() });
});
wgpu::BindGroupDescriptor bind_group_desc; wgpu::BindGroupDescriptor bind_group_desc;
bind_group_desc.layout = pipeline.GetBindGroupLayout(0); bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
@ -265,11 +295,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
encoder.CopyBufferToBuffer( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
params_bufs.host_buf, 0,
params_bufs.dev_buf, 0,
params_bufs.dev_buf.GetSize()
);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline); pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bind_group); pass.SetBindGroup(0, bind_group);
@ -279,11 +305,11 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
if (submit_imm) { if (submit_imm) {
// Submit immediately // Submit immediately
ctx->queue.Submit(1, &commands); ctx->queue.Submit(1, &commands);
ctx->queue.OnSubmittedWorkDone( ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
wgpu::CallbackMode::AllowSpontaneous,
[ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) { if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
message.data);
} }
ctx->param_buf_pool.free_bufs({ params_bufs }); ctx->param_buf_pool.free_bufs({ params_bufs });
}); });
@ -298,9 +324,15 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context& ctx, wgpu::Com
} }
} }
static void ggml_backend_webgpu_buffer_memset(webgpu_context& ctx, wgpu::Buffer& buf, uint32_t value, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
wgpu::Buffer & buf,
uint32_t value,
size_t offset,
size_t size) {
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = {{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }}; std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
};
size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
@ -342,20 +374,35 @@ static void ggml_webgpu_cpy(webgpu_context& ctx, ggml_tensor* src, ggml_tensor*
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = { std::vector<uint32_t> params = { ne,
ne, (uint32_t)(src_misalignment / ggml_type_size(src->type)), (uint32_t)(dst_misalignment / ggml_type_size(dst->type)), (uint32_t) (src_misalignment / ggml_type_size(src->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
// Convert byte-strides to element-strides // Convert byte-strides to element-strides
(uint32_t)(src->nb[0] / ggml_type_size(src->type)), (uint32_t)(src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
(uint32_t)(src->nb[2] / ggml_type_size(src->type)), (uint32_t)(src->nb[3] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
(uint32_t)(dst->nb[0] / ggml_type_size(dst->type)), (uint32_t)(dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
(uint32_t)(dst->nb[2] / ggml_type_size(dst->type)), (uint32_t)(dst->nb[3] / ggml_type_size(dst->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// Logical shape — same for both tensors even if permuted // Logical shape — same for both tensors even if permuted
(uint32_t)src->ne[0], (uint32_t)src->ne[1], (uint32_t)src->ne[2], (uint32_t)src->ne[3] (uint32_t) src->ne[0],
}; (uint32_t) src->ne[1],
(uint32_t) src->ne[2],
(uint32_t) src->ne[3] };
std::vector<wgpu::BindGroupEntry> entries = { std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0, .buffer = ggml_backend_webgpu_tensor_buf(src), .offset = src_offset, .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, { .binding = 0,
{ .binding = 1, .buffer = ggml_backend_webgpu_tensor_buf(dst), .offset = dst_offset, .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } .buffer = ggml_backend_webgpu_tensor_buf(src),
.offset = src_offset,
.size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = dst_offset,
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
}; };
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
@ -381,12 +428,22 @@ static void ggml_webgpu_mul_mat(webgpu_context& ctx, ggml_tensor* src0, ggml_ten
}; };
std::vector<wgpu::BindGroupEntry> entries = { std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0, .buffer = ggml_backend_webgpu_tensor_buf(src0), .offset = ggml_backend_webgpu_tensor_offset(src0), .size = ggml_nbytes(src0) }, { .binding = 0,
{ .binding = 1, .buffer = ggml_backend_webgpu_tensor_buf(src1), .offset = ggml_backend_webgpu_tensor_offset(src1), .size = ggml_nbytes(src1) }, .buffer = ggml_backend_webgpu_tensor_buf(src0),
{ .binding = 2, .buffer = ggml_backend_webgpu_tensor_buf(dst), .offset = ggml_backend_webgpu_tensor_offset(dst), .size = ggml_nbytes(dst) } .offset = ggml_backend_webgpu_tensor_offset(src0),
.size = ggml_nbytes(src0) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(src1),
.offset = ggml_backend_webgpu_tensor_offset(src1),
.size = ggml_nbytes(src1) },
{ .binding = 2,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = ggml_backend_webgpu_tensor_offset(dst),
.size = ggml_nbytes(dst) }
}; };
uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; uint32_t wg_x =
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x); ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
} }
@ -406,11 +463,13 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor* node) {
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
return false; return false;
case GGML_OP_CPY: { case GGML_OP_CPY:
{
ggml_webgpu_cpy(ctx, src0, node); ggml_webgpu_cpy(ctx, src0, node);
break; break;
} }
case GGML_OP_MUL_MAT: { case GGML_OP_MUL_MAT:
{
ggml_webgpu_mul_mat(ctx, src0, src1, node); ggml_webgpu_mul_mat(ctx, src0, src1, node);
break; break;
} }
@ -468,13 +527,18 @@ static void* ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
return webgpu_ptr_base; return webgpu_ptr_base;
} }
static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor* tensor, uint8_t value, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
ggml_tensor * tensor,
uint8_t value,
size_t offset,
size_t size) {
if (size == 0) { if (size == 0) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
return; return;
} }
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
<< offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
@ -483,8 +547,13 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
} }
static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor* tensor, const void* data, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_tensor * tensor,
const void * data,
size_t offset,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", "
<< offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
@ -501,12 +570,18 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
} }
// memset the remaining bytes // memset the remaining bytes
ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); ggml_backend_webgpu_buffer_memset(
webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size);
} }
} }
static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data, size_t offset, size_t size) { static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); const ggml_tensor * tensor,
void * data,
size_t offset,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", "
<< offset << ", " << size << ")");
ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
@ -522,14 +597,16 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex); std::lock_guard<std::mutex> lock(webgpu_ctx->get_tensor_mutex);
if (webgpu_ctx->get_tensor_staging_buf == nullptr || if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
// Create a new staging buffer if it doesn't exist or is too small // Create a new staging buffer if it doesn't exist or is too small
if (webgpu_ctx->get_tensor_staging_buf) { if (webgpu_ctx->get_tensor_staging_buf) {
webgpu_ctx->get_tensor_staging_buf.Destroy(); webgpu_ctx->get_tensor_staging_buf.Destroy();
} }
ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, ggml_webgpu_create_buffer(device,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); webgpu_ctx->get_tensor_staging_buf,
final_size,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead,
"get_tensor_staging_buf");
} }
// Copy the data from the buffer to the staging buffer // Copy the data from the buffer to the staging buffer
@ -577,13 +654,17 @@ static const char* ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_
return ctx->device_name.c_str(); return ctx->device_name.c_str();
} }
static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) {
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")");
ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
wgpu::Buffer buf; wgpu::Buffer buf;
ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size, ggml_webgpu_create_buffer(ctx->webgpu_ctx->device,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); buf,
size,
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
"allocated_buffer");
ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf);
@ -650,7 +731,8 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context& webgpu_ctx) {
size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled // Size the bytes_per_thread so that the largest buffer size can be handled
webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; webgpu_ctx->memset_bytes_per_thread =
(webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads;
std::vector<wgpu::ConstantEntry> constants(2); std::vector<wgpu::ConstantEntry> constants(2);
constants[0].key = "wg_size"; constants[0].key = "wg_size";
constants[0].value = max_wg_size; constants[0].value = max_wg_size;
@ -686,17 +768,23 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
dev_desc.requiredLimits = &webgpu_ctx->limits; dev_desc.requiredLimits = &webgpu_ctx->limits;
dev_desc.requiredFeatures = webgpu_ctx->features.features; dev_desc.requiredFeatures = webgpu_ctx->features.features;
dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount; dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, dev_desc.SetDeviceLostCallback(
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device); GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data); GGML_LOG_ERROR(
"ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
}); });
dev_desc.SetUncapturedErrorCallback( dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
GGML_UNUSED(device); GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data); GGML_LOG_ERROR(
"ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), message.data);
}); });
webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::AllowSpontaneous, webgpu_ctx->instance.WaitAny(
webgpu_ctx->adapter.RequestDevice(
&dev_desc,
wgpu::CallbackMode::AllowSpontaneous,
[webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
if (status != wgpu::RequestDeviceStatus::Success) { if (status != wgpu::RequestDeviceStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data);
@ -704,8 +792,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
} }
webgpu_ctx->device = std::move(device); webgpu_ctx->device = std::move(device);
}), }),
UINT64_MAX UINT64_MAX);
);
GGML_ASSERT(webgpu_ctx->device != nullptr); GGML_ASSERT(webgpu_ctx->device != nullptr);
// Initialize (compute) queue // Initialize (compute) queue
@ -746,7 +833,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ NULL, // defaults to false /* .is_host = */ NULL, // defaults to false
}, },
/* .device = */ dev, /* .device = */
dev,
/* .context = */ NULL, /* .context = */ NULL,
}; };
@ -818,7 +906,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx; webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {}; wgpu::RequestAdapterOptions options = {};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char* message, void* userdata) { auto callback =
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) { if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return; return;
@ -826,7 +915,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter); *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
}; };
void * userdata = &ctx->adapter; void * userdata = &ctx->adapter;
ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr); GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits); ctx->adapter.GetLimits(&ctx->limits);
@ -840,8 +930,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
device_ctx.device_name = GGML_WEBGPU_NAME; device_ctx.device_name = GGML_WEBGPU_NAME;
device_ctx.device_desc = std::string(info.description.data); device_ctx.device_desc = std::string(info.description.data);
GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n", GGML_LOG_INFO(
info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data); "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
"device_desc: %s\n",
info.vendorID,
info.vendor.data,
info.architecture.data,
info.deviceID,
info.device.data,
info.description.data);
// See GGML Backend Device Interface section // See GGML Backend Device Interface section
static ggml_backend_device device = { static ggml_backend_device device = {
@ -852,7 +949,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
return &device; return &device;
} }
static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
/* .get_name = */ ggml_backend_webgpu_reg_get_name, /* .get_name = */ ggml_backend_webgpu_reg_get_name,
/* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count, /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,