From 4d828bd1ab52773ba9570cc008cf209eb4a8b2f5 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Mon, 2 Mar 2026 10:23:34 -0800 Subject: [PATCH] ggml webgpu: Clean up per-thread parameter buffer pool and job submission logic (#19772) * Allow webgpu_buf_pool to resize if needed, remove inflight_threads, and use num_kernels for submission batching * Run clang-format * Keep track of num batched kernels that have not been submitted yet * Run clang-format * Increase buf pool max size * Increase param buf pool init size * Remove webgpu buf pool resizing * Merge with master * Add buffer pool growth * Move buffer pool growth outside of lock * Reduce max pool size to 32 * Run clang-format * Only resize param buf pool --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 65 ++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4dc56e1dc5..913cf7f882 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -133,12 +133,28 @@ struct webgpu_buf_pool { // which can run on a different thread than the calling thread. 
std::mutex mutex; std::condition_variable cv; + size_t cur_pool_size; + size_t max_pool_size; + wgpu::Device device; + wgpu::BufferUsage host_buf_usage; + wgpu::BufferUsage dev_buf_usage; + size_t buf_size; + bool should_grow; void init(wgpu::Device device, int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { + wgpu::BufferUsage host_buf_usage, + bool should_grow = false, + size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->host_buf_usage = host_buf_usage; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { wgpu::Buffer host_buf; wgpu::Buffer dev_buf; @@ -150,6 +166,25 @@ struct webgpu_buf_pool { webgpu_pool_bufs alloc_bufs() { std::unique_lock lock(mutex); + if (!free.empty()) { + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + // Try growing the pool if no free buffers + if (free.empty() && cur_pool_size < max_pool_size && should_grow) { + cur_pool_size++; + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + + if (!(host_buf && dev_buf)) { + GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); + } + return webgpu_pool_bufs{ host_buf, dev_buf }; + } cv.wait(lock, [this] { return !free.empty(); }); webgpu_pool_bufs bufs = free.back(); free.pop_back(); @@ -243,6 +278,7 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { + uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; std::optional set_rows_error_bufs; @@ -280,7 +316,6 @@ struct webgpu_global_context_struct { webgpu_buf_pool memset_buf_pool; std::map memset_pipelines; // variant or type index - 
std::atomic_uint inflight_threads = 0; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -426,13 +461,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If - // there are many threads, inflight_max may be 0, meaning that we must wait on - // all futures. - uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint32_t inflight_threads = ctx->inflight_threads; - uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); - while (futures.size() >= inflight_max && futures.size() > 0) { + // If we have too many in-flight submissions, wait on the oldest one first. + uint64_t timeout_ms = block ? UINT64_MAX : 0; + while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); } @@ -651,6 +682,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( result.commands = commands; result.params_bufs = params_bufs_list; result.set_rows_error_bufs = set_rows_error_bufs; + result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; // TODO: handle multiple pipeline names @@ -2081,19 +2113,17 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->global_ctx->inflight_threads++; - std::vector commands; std::vector futures; + uint32_t num_batched_kernels = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); + num_batched_kernels += cmd.value().num_kernels; } - // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->global_ctx->inflight_threads; 
- uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); - if (commands.size() >= batch_size) { + + if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + num_batched_kernels = 0; futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool)); // Process events and check for completed submissions @@ -2109,7 +2139,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } ggml_backend_webgpu_wait(ctx->global_ctx, futures); - ctx->global_ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2727,7 +2756,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,