[WebGPU] Fix wait logic for inflight jobs (#20096)

* Enable tmate debugging for investigating thread safety issue

* Refactor wait and submit to operate on vector<wgpu::FutureWaitInfo>, and fix wait to delete only the future that is completed.

* Cleanup

* Remove clear change and run clang-format

* Cleanup
This commit is contained in:
Nikhil Jain 2026-03-04 11:54:55 -08:00 committed by GitHub
parent 541bf37622
commit 24d2ee0527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 65 additions and 35 deletions

View File

@ -123,11 +123,6 @@ struct webgpu_pool_bufs {
wgpu::Buffer dev_buf;
};
// The futures to wait on for a single queue submission.
struct webgpu_submission_futures {
// Wait records passed to wgpu::Instance::WaitAny; each entry pairs a
// wgpu::Future with a `completed` flag that WaitAny fills in. NOTE(review):
// presumably one entry per work-done callback of the submission — confirm
// against the submit path, which is not fully visible here.
std::vector<wgpu::FutureWaitInfo> futures;
};
// Holds a pool of parameter buffers for WebGPU operations
struct webgpu_buf_pool {
std::vector<webgpu_pool_bufs> free;
@ -463,26 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
/** End WebGPU object initializations */
/** WebGPU Actions */
// Drop every wait record whose `completed` flag has been set (by a prior
// WaitAny call), compacting the vector in place and preserving the relative
// order of the still-pending entries.
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
    const auto is_done = [](const wgpu::FutureWaitInfo & info) { return info.completed; };
    auto      pending_end = std::remove_if(futures.begin(), futures.end(), is_done);
    futures.erase(pending_end, futures.end());
}
// Wait for the queue to finish processing all submitted work
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<webgpu_submission_futures> & futures,
bool block = true) {
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
std::vector<wgpu::FutureWaitInfo> & futures,
bool block = true) {
// If we have too many in-flight submissions, wait on the oldest one first.
if (futures.empty()) {
return;
}
uint64_t timeout_ms = block ? UINT64_MAX : 0;
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
futures.erase(futures.begin());
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
if (waitStatus == wgpu::WaitStatus::Error) {
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
}
if (futures[0].completed) {
futures.erase(futures.begin());
}
}
size_t i = 0;
while (i < futures.size()) {
auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
if (futures.empty()) {
return;
}
if (block) {
while (!futures.empty()) {
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
break;
default:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
break;
}
}
} else {
// Poll once and return
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
switch (waitStatus) {
case wgpu::WaitStatus::Success:
futures.erase(futures.begin() + i);
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
erase_completed(futures);
break;
case wgpu::WaitStatus::TimedOut:
i++;
break;
case wgpu::WaitStatus::Error:
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
@ -525,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
}
#endif
static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
webgpu_global_context ctx,
std::vector<webgpu_command> commands,
webgpu_buf_pool & param_buf_pool,
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
std::vector<wgpu::CommandBuffer> command_buffers;
std::vector<webgpu_pool_bufs> params_bufs;
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
@ -600,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex
futures.push_back({ f });
}
#endif
return { futures };
return futures;
}
static webgpu_command ggml_backend_webgpu_build_multi(
@ -727,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
webgpu_command command =
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
ctx->memset_buf_pool) };
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
ggml_backend_webgpu_wait(ctx, futures);
}
@ -836,7 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
return flags;
}
@ -1153,8 +1182,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
};
// Calculate workgroup dimensions
uint32_t wg_x = 1;
uint32_t wg_y = 1;
uint32_t wg_x = 1;
uint32_t wg_y = 1;
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
if (use_fast && is_vec) {
@ -1410,7 +1439,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
uint32_t offset_merged_src0 = 0;
uint32_t offset_merged_src1 = 0;
if (flags.src_overlap) {
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
}
@ -1419,7 +1448,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
offset_merged_src0,
offset_merged_src1,
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
@ -2185,9 +2214,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
std::vector<webgpu_command> commands;
std::vector<webgpu_submission_futures> futures;
uint32_t num_batched_kernels = 0;
std::vector<webgpu_command> commands;
std::vector<wgpu::FutureWaitInfo> futures;
uint32_t num_batched_kernels = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
commands.push_back(*cmd);
@ -2195,9 +2224,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
}
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
num_batched_kernels = 0;
futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
&ctx->set_rows_error_buf_pool));
num_batched_kernels = 0;
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
// Process events and check for completed submissions
ctx->global_ctx->instance.ProcessEvents();
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
@ -2205,9 +2235,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
}
}
if (!commands.empty()) {
webgpu_submission_futures new_futures =
auto new_futures =
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
futures.push_back(new_futures);
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
}
ggml_backend_webgpu_wait(ctx->global_ctx, futures);