diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 334919e589..b2ef2d5901 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -123,11 +123,6 @@ struct webgpu_pool_bufs {
     wgpu::Buffer dev_buf;
 };
 
-// The futures to wait on for a single queue submission
-struct webgpu_submission_futures {
-    std::vector<wgpu::FutureWaitInfo> futures;
-};
-
 // Holds a pool of parameter buffers for WebGPU operations
 struct webgpu_buf_pool {
     std::vector<webgpu_pool_bufs> free;
@@ -463,26 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
 /** End WebGPU object initializations */
 
 /** WebGPU Actions */
 
+static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
+    futures.erase(std::remove_if(futures.begin(), futures.end(),
+                                 [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
+                  futures.end());
+}
+
 // Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
-                                     std::vector<webgpu_submission_futures> & futures,
-                                     bool                                     block = true) {
+static void ggml_backend_webgpu_wait(webgpu_global_context &             ctx,
+                                     std::vector<wgpu::FutureWaitInfo> & futures,
+                                     bool                                block = true) {
     // If we have too many in-flight submissions, wait on the oldest one first.
+    if (futures.empty()) {
+        return;
+    }
     uint64_t timeout_ms = block ? UINT64_MAX : 0;
     while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
-        ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
-        futures.erase(futures.begin());
+        auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
+        if (waitStatus == wgpu::WaitStatus::Error) {
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+        }
+        if (futures[0].completed) {
+            futures.erase(futures.begin());
+        }
     }
-    size_t i = 0;
-    while (i < futures.size()) {
-        auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
+
+    if (futures.empty()) {
+        return;
+    }
+
+    if (block) {
+        while (!futures.empty()) {
+            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+            switch (waitStatus) {
+                case wgpu::WaitStatus::Success:
+                    // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
+                    erase_completed(futures);
+                    break;
+                case wgpu::WaitStatus::Error:
+                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+                    break;
+                default:
+                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+                    break;
+            }
+        }
+    } else {
+        // Poll once and return
+        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
         switch (waitStatus) {
             case wgpu::WaitStatus::Success:
-                futures.erase(futures.begin() + i);
+                // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
+                erase_completed(futures);
                 break;
             case wgpu::WaitStatus::TimedOut:
-                i++;
                 break;
             case wgpu::WaitStatus::Error:
                 GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
@@ -525,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 }
 #endif
 
-static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context       ctx,
-                                                            std::vector<webgpu_command> commands,
-                                                            webgpu_buf_pool &           param_buf_pool,
-                                                            webgpu_buf_pool *           set_rows_error_buf_pool = nullptr) {
+static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
+    webgpu_global_context       ctx,
+    std::vector<webgpu_command> commands,
+    webgpu_buf_pool &           param_buf_pool,
+    webgpu_buf_pool *           set_rows_error_buf_pool = nullptr) {
     std::vector<wgpu::CommandBuffer> command_buffers;
     std::vector<webgpu_pool_bufs>    params_bufs;
     std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
@@ -600,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex
         futures.push_back({ f });
     }
 #endif
-    return { futures };
+    return futures;
 }
 
 static webgpu_command ggml_backend_webgpu_build_multi(
@@ -727,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
 
     webgpu_command command =
         ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
-    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
-                                                                                  ctx->memset_buf_pool) };
+    auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
     ggml_backend_webgpu_wait(ctx, futures);
 }
 
@@ -836,7 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
     binary_overlap_flags flags = {};
     flags.inplace     = ggml_webgpu_tensor_equal(src0, dst);
     flags.overlap     = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); 
+    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
     return flags;
 }
 
@@ -1153,8 +1182,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     };
 
     // Calculate workgroup dimensions
-    uint32_t       wg_x = 1;
-    uint32_t       wg_y = 1;
+    uint32_t wg_x = 1;
+    uint32_t wg_y = 1;
     const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
 
     if (use_fast && is_vec) {
@@ -1410,7 +1439,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
     uint32_t offset_merged_src0 = 0;
     uint32_t offset_merged_src1 = 0;
     if (flags.src_overlap) {
-        size_t min_off     = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
         offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
         offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
     }
@@ -1419,7 +1448,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
         ne,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), 
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         offset_merged_src0,
         offset_merged_src1,
         (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
@@ -2185,9 +2214,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    std::vector<webgpu_command>            commands;
-    std::vector<webgpu_submission_futures> futures;
-    uint32_t                               num_batched_kernels = 0;
+    std::vector<webgpu_command>       commands;
+    std::vector<wgpu::FutureWaitInfo> futures;
+    uint32_t                          num_batched_kernels = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
             commands.push_back(*cmd);
@@ -2195,9 +2224,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
             }
 
             if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-                num_batched_kernels = 0;
-                futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
-                                                             &ctx->set_rows_error_buf_pool));
+                num_batched_kernels = 0;
+                std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
+                    ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
+                futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
                 // Process events and check for completed submissions
                 ctx->global_ctx->instance.ProcessEvents();
                 ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
@@ -2205,9 +2235,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         }
     }
     if (!commands.empty()) {
-        webgpu_submission_futures new_futures =
+        auto new_futures =
             ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
-        futures.push_back(new_futures);
+        futures.insert(futures.end(), new_futures.begin(), new_futures.end());
     }
 
     ggml_backend_webgpu_wait(ctx->global_ctx, futures);