Fix thread-safe implementation

Reese Levine, 2025-07-31 11:02:08 -07:00
parent bfff27f130
commit b8012ecc0a
1 changed file with 49 additions and 38 deletions


@@ -13,7 +13,9 @@
+#include <condition_variable>
 #include <cstring>
 #include <iostream>
+#include <mutex>
 #include <string>
 #include <vector>

 #ifdef GGML_WEBGPU_DEBUG
@@ -61,7 +63,8 @@ struct webgpu_param_bufs {
 struct webgpu_param_buf_pool {
     std::vector<webgpu_param_bufs> free;
-    std::mutex mutex;
+    std::mutex              mutex;
+    std::condition_variable cv;

     void init(wgpu::Device device) {
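The pool pairs its mutex with the newly added condition variable so that an allocation can block until a previously submitted batch returns its buffers. A minimal sketch of that pattern, with illustrative names rather than the exact ggml code:

    #include <condition_variable>
    #include <mutex>
    #include <vector>

    template <typename T> struct blocking_pool {
        std::vector<T>          free;
        std::mutex              mutex;
        std::condition_variable cv;

        T alloc() {
            std::unique_lock<std::mutex> lock(mutex);
            cv.wait(lock, [this] { return !free.empty(); });  // sleep until a buffer is returned
            T bufs = free.back();
            free.pop_back();
            return bufs;
        }

        void release(T bufs) {
            std::lock_guard<std::mutex> lock(mutex);
            free.push_back(std::move(bufs));
            cv.notify_one();  // wake one blocked alloc()
        }
    };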
@@ -108,19 +111,18 @@ struct webgpu_param_buf_pool {
 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
-    wgpu::Instance instance;
-    wgpu::Adapter adapter;
-    wgpu::Device device;
-    wgpu::Queue queue;
-    wgpu::Limits limits;
-    wgpu::SupportedFeatures features;
+    wgpu::Instance instance;
+    wgpu::Adapter  adapter;
+    wgpu::Device   device;
+    wgpu::Queue    queue;
+    wgpu::Limits   limits;

-    std::recursive_mutex submit_mutex;
+    std::recursive_mutex mutex;
-    std::mutex get_tensor_mutex;
     std::mutex init_mutex;
+
-    bool device_init = false;
     // Parameter buffer pool
+    bool device_init = false;

     webgpu_param_buf_pool param_buf_pool;

     wgpu::ComputePipeline memset_pipeline;
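Consolidating submit_mutex and get_tensor_mutex into a single std::recursive_mutex matters because the lock is re-entered on the same thread: ggml_backend_webgpu_build_and_enqueue now takes the lock for its whole body (see the hunk at line 276 below) and, once enough commands are staged, evidently calls ggml_backend_webgpu_submit_queue, which locks the same mutex again. A hedged sketch of why the recursive flavor is needed:

    #include <mutex>

    std::recursive_mutex m;

    void submit_queue() {
        std::lock_guard<std::recursive_mutex> lock(m);  // second acquisition, same thread: OK
        // ... submit staged command buffers ...
    }

    void build_and_enqueue() {
        std::lock_guard<std::recursive_mutex> lock(m);  // first acquisition
        // ... stage a command buffer ...
        submit_queue();  // with a plain std::mutex this re-lock would deadlock
    }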
@@ -134,36 +136,33 @@ struct webgpu_context_struct {
     // Command buffers which need to be submitted
     std::vector<wgpu::CommandBuffer> staged_command_bufs;
     // Parameter buffers associated with the staged command buffers
-    std::vector<webgpu_param_bufs> staged_param_bufs;
+    std::vector<webgpu_param_bufs>   staged_param_bufs;
 };

 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;

 struct ggml_backend_webgpu_reg_context {
     webgpu_context webgpu_ctx;
-    size_t device_count;
-    const char * name;
+    size_t         device_count;
+    const char *   name;
 };

 struct ggml_backend_webgpu_device_context {
     webgpu_context webgpu_ctx;
-    std::string device_name;
-    std::string device_desc;
+    std::string    device_name;
+    std::string    device_desc;
 };

 struct ggml_backend_webgpu_context {
     webgpu_context webgpu_ctx;
-    std::string name;
+    std::string    name;
 };

 struct ggml_backend_webgpu_buffer_context {
     webgpu_context webgpu_ctx;
-    wgpu::Buffer buffer;
+    wgpu::Buffer   buffer;

     ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
         webgpu_ctx(std::move(ctx)),
@@ -180,10 +179,13 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
                                         const char * label,
                                         const std::vector<wgpu::ConstantEntry> & constants = {}) {
     WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
+
     wgpu::ShaderSourceWGSL shader_source;
     shader_source.code = shader_code;
+
     wgpu::ShaderModuleDescriptor shader_desc;
-    shader_desc.nextInChain = &shader_source;
+    shader_desc.nextInChain = &shader_source;
     wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

     wgpu::ComputePipelineDescriptor pipeline_desc;
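WebGPU's C++ API extends base descriptors through the nextInChain pointer rather than extra parameters, so the WGSL source rides along as a chained struct. A standalone sketch of the pattern, assuming an already initialized device and a trivial placeholder shader:

    wgpu::ShaderSourceWGSL wgsl;
    wgsl.code = "@compute @workgroup_size(1) fn main() {}";  // placeholder WGSL

    wgpu::ShaderModuleDescriptor shader_desc;
    shader_desc.nextInChain = &wgsl;  // chain the WGSL payload onto the base descriptor

    wgpu::ShaderModule module = device.CreateShaderModule(&shader_desc);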
@@ -210,8 +212,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
     buffer_desc.usage = usage;
     buffer_desc.label = label;
     buffer_desc.mappedAtCreation = false;
+
     // TODO: error handling
-    buffer = device.CreateBuffer(&buffer_desc);
+    buffer = device.CreateBuffer(&buffer_desc);
 }

 /** End WebGPU object initializations */
@@ -231,8 +234,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
 }

 static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
-    std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex);
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
     ctx->staged_command_bufs.clear();
     std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
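Note the pattern in the last line above: the staged parameter buffers are moved into a local vector while ctx->mutex is still held, so the shared vector is emptied atomically and the completion callback can recycle the local copy without racing newly staged work. The idea in isolation (the scope block is illustrative; in the diff the whole function runs under the lock):

    std::vector<webgpu_param_bufs> staged;
    {
        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
        staged = std::move(ctx->staged_param_bufs);  // shared vector is now empty
    }
    // `staged` is private to this submission; freeing it later needs no lock.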
@@ -274,6 +276,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   bool submit_imm = false) {
     webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
     uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
     for (size_t i = 0; i < params.size(); i++) {
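The ordering in this hunk looks deliberate: alloc_bufs() can block on the pool's condition variable, so it runs before ctx->mutex is taken; otherwise a thread waiting for a free parameter buffer would stall every other thread that needs the context lock. A sketch of the two orderings, using the names from the diff:

    // As in the diff: wait on the pool first, then serialize on the context.
    webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();  // may block on cv
    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);

    // Reversed, this would hold ctx->mutex across the wait and throttle all threads:
    // std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
    // webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();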
@@ -315,7 +319,6 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
         });
     } else {
         // Enqueue commands and only submit if we have enough staged commands
-        std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex);
         ctx->staged_command_bufs.push_back(commands);
         ctx->staged_param_bufs.push_back(params_bufs);
         if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
@@ -540,10 +543,12 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
                      << offset << ", " << size << ")");
-    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+
+    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
-    uint32_t val32 = (uint32_t) value * 0x01010101;
+    uint32_t val32 = (uint32_t) value * 0x01010101;
     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
 }
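The 0x01010101 multiplication is the standard byte-splat: an 8-bit value times 0x01010101 lands in all four byte lanes, e.g. 0xAB * 0x01010101 = 0xABABABAB. A self-contained check:

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t  value = 0xAB;
        uint32_t val32 = (uint32_t) value * 0x01010101u;  // replicate the byte into all 4 lanes
        printf("0x%08X\n", val32);                        // prints 0xABABABAB
    }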
@@ -559,13 +564,16 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

     if (size % 4 != 0) {
         // If size is not a multiple of 4, we need to memset the remaining bytes
-        size_t remaining_size = size % 4;
+        size_t remaining_size = size % 4;
+
         // pack the remaining bytes into a uint32_t
-        uint32_t val32 = 0;
+        uint32_t val32 = 0;
+
         for (size_t i = 0; i < remaining_size; i++) {
             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
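The split at (size / 4) * 4 exists because WebGPU requires queue.WriteBuffer sizes to be a multiple of 4 bytes; the loop then packs the 1 to 3 leftover bytes into the low bytes of a u32, which the memset path writes out. The packing step in isolation, as a sketch with a hypothetical helper name:

    #include <cstddef>
    #include <cstdint>

    // Hypothetical helper: pack the trailing `size % 4` bytes of `data` into
    // the low bytes of a u32, matching the loop in the diff (byte i of the
    // u32's storage receives source byte i).
    static uint32_t pack_tail(const void * data, size_t size) {
        size_t   remaining_size = size % 4;
        uint32_t val32          = 0;
        for (size_t i = 0; i < remaining_size; i++) {
            ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
        }
        return val32;
    }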
@@ -613,8 +621,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
-    // Submit the command buffer to the queue
-    webgpu_ctx->queue.Submit(1, &commands);
+
+    {
+        std::lock_guard<std::recursive_mutex> submit_lock(webgpu_ctx->mutex);
+        // Submit the command buffer to the queue
+        webgpu_ctx->queue.Submit(1, &commands);
+    }

     // Map the staging buffer to read the data
     ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
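Readback here follows the usual WebGPU staging pattern: compute buffers are not mappable for reading, so the data is copied into a staging buffer created with CopyDst | MapRead usage, the copy is submitted (now under the shared mutex), and only then is the staging buffer mapped. Roughly, assuming valid device, queue, and buffer handles and a suitably sized staging buffer:

    wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
    encoder.CopyBufferToBuffer(src_buf, src_offset, staging_buf, 0, final_size);
    wgpu::CommandBuffer commands = encoder.Finish();
    queue.Submit(1, &commands);
    // After the submission completes, MapAsync(wgpu::MapMode::Read, ...) on
    // staging_buf followed by GetConstMappedRange() exposes the bytes to the CPU.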
@@ -628,7 +640,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;

     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
 }
@@ -764,10 +775,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
     if (!webgpu_ctx->device_init) {
         // Initialize device
-        wgpu::DeviceDescriptor dev_desc;
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
+        wgpu::DeviceDescriptor         dev_desc;
         dev_desc.requiredLimits = &webgpu_ctx->limits;
-        dev_desc.requiredFeatures = webgpu_ctx->features.features;
-        dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
+        dev_desc.requiredFeatures     = required_features.data();
+        dev_desc.requiredFeatureCount = required_features.size();
         dev_desc.SetDeviceLostCallback(
             wgpu::CallbackMode::AllowSpontaneous,
             [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
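With the device now created from an explicit required_features list (just ShaderF16) instead of forwarding every adapter feature, the context no longer needs to cache wgpu::SupportedFeatures, which is why the features member and the GetFeatures call are dropped elsewhere in this diff. A sketch of a slightly more defensive variant (the HasFeature guard is an addition for illustration, not part of this commit):

    // Requesting a feature the adapter lacks fails device creation, so guard it.
    std::vector<wgpu::FeatureName> required_features;
    if (adapter.HasFeature(wgpu::FeatureName::ShaderF16)) {
        required_features.push_back(wgpu::FeatureName::ShaderF16);
    }

    wgpu::DeviceDescriptor dev_desc;
    dev_desc.requiredFeatures     = required_features.data();
    dev_desc.requiredFeatureCount = required_features.size();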
@@ -920,7 +932,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     GGML_ASSERT(ctx->adapter != nullptr);

     ctx->adapter.GetLimits(&ctx->limits);
-    ctx->adapter.GetFeatures(&ctx->features);

     wgpu::AdapterInfo info{};
     ctx->adapter.GetInfo(&info);