ggml webgpu: Clean up per-thread parameter buffer pool and job submission logic (#19772)

* Allow webgpu_buf_pool to resize if needed, remove inflight_threads, and use num_kernels for submission batching

* Run clang-format

* Keep track of num batched kernels that have not been submitted yet

* Run clang-format

* Increase buf pool max size

* Increase param buf pool init size

* Remove webgpu buf pool resizing

* Merge with master

* Add buffer pool growth

* Move buffer pool growth outside of lock

* Reduce max pool size to 32

* Run clang-format

* Only resize param buf pool
This commit is contained in:
Nikhil Jain 2026-03-02 10:23:34 -08:00 committed by GitHub
parent 36a7a6589c
commit 4d828bd1ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 47 additions and 18 deletions

View File

@@ -133,12 +133,28 @@ struct webgpu_buf_pool {
// which can run on a different thread than the calling thread.
std::mutex mutex;
std::condition_variable cv;
size_t cur_pool_size;
size_t max_pool_size;
wgpu::Device device;
wgpu::BufferUsage host_buf_usage;
wgpu::BufferUsage dev_buf_usage;
size_t buf_size;
bool should_grow;
void init(wgpu::Device device,
int num_bufs,
size_t buf_size,
wgpu::BufferUsage dev_buf_usage,
wgpu::BufferUsage host_buf_usage) {
wgpu::BufferUsage host_buf_usage,
bool should_grow = false,
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
this->max_pool_size = max_pool_size;
this->cur_pool_size = num_bufs;
this->device = device;
this->host_buf_usage = host_buf_usage;
this->dev_buf_usage = dev_buf_usage;
this->buf_size = buf_size;
this->should_grow = should_grow;
for (int i = 0; i < num_bufs; i++) {
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
@@ -150,6 +166,25 @@ struct webgpu_buf_pool {
webgpu_pool_bufs alloc_bufs() {
std::unique_lock<std::mutex> lock(mutex);
if (!free.empty()) {
webgpu_pool_bufs bufs = free.back();
free.pop_back();
return bufs;
}
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
wgpu::Buffer host_buf;
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
if (!(host_buf && dev_buf)) {
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
}
return webgpu_pool_bufs{ host_buf, dev_buf };
}
cv.wait(lock, [this] { return !free.empty(); });
webgpu_pool_bufs bufs = free.back();
free.pop_back();
@@ -243,6 +278,7 @@ struct webgpu_gpu_profile_buf_pool {
#endif
struct webgpu_command {
uint32_t num_kernels;
wgpu::CommandBuffer commands;
std::vector<webgpu_pool_bufs> params_bufs;
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
@@ -280,7 +316,6 @@ struct webgpu_global_context_struct {
webgpu_buf_pool memset_buf_pool;
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
std::atomic_uint inflight_threads = 0;
#ifdef GGML_WEBGPU_CPU_PROFILE
// Profiling: labeled CPU time in ms (total)
@@ -426,13 +461,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission_futures> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first. If
// there are many threads, inflight_max may be 0, meaning that we must wait on
// all futures.
uint64_t timeout_ms = block ? UINT64_MAX : 0;
uint32_t inflight_threads = ctx->inflight_threads;
uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
while (futures.size() >= inflight_max && futures.size() > 0) {
// If we have too many in-flight submissions, wait on the oldest one first.
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
futures.erase(futures.begin());
}
@@ -651,6 +682,7 @@ static webgpu_command ggml_backend_webgpu_build_multi(
result.commands = commands;
result.params_bufs = params_bufs_list;
result.set_rows_error_bufs = set_rows_error_bufs;
result.num_kernels = pipelines.size();
#ifdef GGML_WEBGPU_GPU_PROFILE
result.timestamp_query_bufs = ts_bufs;
// TODO: handle multiple pipeline names
@@ -2081,19 +2113,17 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
ctx->global_ctx->inflight_threads++;
std::vector<webgpu_command> commands;
std::vector<webgpu_submission_futures> futures;
uint32_t num_batched_kernels = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
num_batched_kernels += cmd.value().num_kernels;
}
// compute the batch size based on the number of inflight threads
uint32_t inflight_threads = ctx->global_ctx->inflight_threads;
uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
if (commands.size() >= batch_size) {
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
&ctx->set_rows_error_buf_pool));
// Process events and check for completed submissions
@@ -2109,7 +2139,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
}
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
ctx->global_ctx->inflight_threads--;
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
return GGML_STATUS_SUCCESS;
}
@@ -2727,7 +2756,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device);
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,