From d006858316d4650bb4da0c6923294ccd741caefd Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 3 Apr 2026 11:40:14 -0700 Subject: [PATCH] ggml-webgpu: move from parameter buffer pool to single buffer with offsets (#21278) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 43 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 746 ++++++++---------- 2 files changed, 373 insertions(+), 416 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 1c56c68931..669d2cd53a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ // Head-dim specializations used by the tuned vec f16 path. switch (key.head_dim_qk) { - case 64: return 2u; - case 96: return 4u; - case 128: return 1u; - case 192: return 2u; - case 576: return 2u; - default: return 1u; + case 64: + return 2u; + case 96: + return 4u; + case 128: + return 1u; + case 192: + return 2u; + case 576: + return 2u; + default: + return 1u; } } @@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context { }; inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { std::vector defines; std::string variant = "flash_attn_vec_blk"; @@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.use_vec) { q_tile = 1; kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); @@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib { } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + pipeline.context = decisions; flash_attn_pipelines[context.key] = pipeline; return flash_attn_pipelines[context.key]; } @@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); flash_attn_blk_pipelines[context.key] = pipeline; return flash_attn_blk_pipelines[context.key]; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e53281bfbb..5c567dc0df 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -81,12 +81,10 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 96u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_SLOTS \ + (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots #define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the -// parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -122,87 +120,45 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; +// Slot-based parameter arena for compute graph encoding. Each encoded kernel +// gets a unique uniform-buffer slice within the current batch, and the slot +// cursor is reset immediately after that batch is submitted. +struct webgpu_param_arena { + wgpu::Buffer buffer; + size_t slot_stride = 0; + size_t slot_size = 0; + uint32_t slot_count = 0; + uint32_t next_slot = 0; - // The pool must be synchronized because - // 1. The memset pool is shared globally by every ggml buffer, - // since allocating a pool per ggml buffer would consume too much memory. - // 2. For the per-thread buffer pools in webgpu_context, - // buffers are allocated and freed in Dawn callbacks, - // which can run on a different thread than the calling thread. - std::mutex mutex; - std::condition_variable cv; - size_t cur_pool_size; - size_t max_pool_size; - wgpu::Device device; - wgpu::BufferUsage dev_buf_usage; - size_t buf_size; - bool should_grow; + void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) { + this->slot_stride = ROUNDUP_POW2(slot_size, alignment); + this->slot_size = slot_size; + this->slot_count = slot_count; + this->next_slot = 0; - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - bool should_grow = false, - size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back(dev_buf); - } + ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena"); } - wgpu::Buffer alloc_bufs() { - std::unique_lock lock(mutex); - if (!free.empty()) { - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + size_t alloc_slot(size_t size) { + GGML_ASSERT(size <= slot_size); + if (next_slot >= slot_count) { + GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch"); } - // Try growing the pool if no free buffers - if (free.empty() && cur_pool_size < max_pool_size && should_grow) { - cur_pool_size++; - lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - - if (!dev_buf) { - GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); - } - return dev_buf; - } - cv.wait(lock, [this] { return !free.empty(); }); - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + return slot_stride * next_slot++; } - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } + void reset() { next_slot = 0; } void cleanup() { - std::lock_guard lock(mutex); - for (auto & buf : free) { - if (buf) { - buf.Destroy(); - } + if (buffer) { + buffer.Destroy(); + buffer = nullptr; } - free.clear(); } - ~webgpu_buf_pool() { this->cleanup(); } + ~webgpu_param_arena() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -269,10 +225,8 @@ struct webgpu_gpu_profile_buf_pool { }; #endif -struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; +struct webgpu_encoded_op { + uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -305,8 +259,8 @@ struct webgpu_global_context_struct { // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. std::recursive_mutex mutex; - webgpu_buf_pool memset_buf_pool; - std::map memset_pipelines; // variant or type index + wgpu::Buffer memset_params_buf; + webgpu_pipeline memset_pipeline; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -332,6 +286,10 @@ struct webgpu_global_context_struct { this->get_tensor_staging_buf.Destroy(); this->get_tensor_staging_buf = nullptr; } + if (this->memset_params_buf) { + this->memset_params_buf.Destroy(); + this->memset_params_buf = nullptr; + } #ifdef GGML_WEBGPU_DEBUG if (this->debug_host_buf) { this->debug_host_buf.Destroy(); @@ -347,13 +305,6 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; -struct webgpu_submission { - wgpu::FutureWaitInfo submit_done; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; -#endif -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -361,9 +312,9 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_buf_pool param_buf_pool; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; size_t memset_bytes_per_thread; }; @@ -448,95 +399,34 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { - switch (status) { - case wgpu::WaitStatus::Success: - return true; - case wgpu::WaitStatus::TimedOut: - if (allow_timeout) { - return false; - } - GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); - return false; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - return false; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - return false; - } -} - #ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); -} - static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures, - bool block) { + std::vector & futures) { if (futures.empty()) { return; } - uint64_t timeout_ms = block ? UINT64_MAX : 0; - if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } - } else { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } + constexpr size_t max_futures_per_wait = 64; + + while (!futures.empty()) { + ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); } } #endif -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & subs, - bool block = true) { - if (subs.empty()) { - return; - } - - bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); -#endif - subs.erase(subs.begin()); - } - blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - } - - if (subs.empty()) { - return; - } - - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (success && sub->profile_futures.empty()) { -#else - if (success) { -#endif - sub = subs.erase(sub); - } else { - ++sub; - } - } +static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { + ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -570,34 +460,10 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, - std::vector & commands, - webgpu_buf_pool & param_buf_pool) { - std::vector command_buffers; - std::vector params_bufs; - webgpu_submission submission; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector> pipeline_name_and_ts_bufs; -#endif - - for (const auto & command : commands) { - command_buffers.push_back(command.commands); - params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - } - ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - param_buf_pool.free_bufs(params_bufs); - }); - submission.submit_done = { p_f }; - #ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, + const std::vector & commands, + std::vector & futures) { for (const auto & command : commands) { auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; @@ -616,15 +482,15 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // We can't unmap in here due to WebGPU reentrancy limitations. ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - submission.profile_futures.push_back({ f }); + futures.push_back({ f }); } -#endif - return submission; } +#endif -static webgpu_command ggml_backend_webgpu_build_multi( +static webgpu_encoded_op ggml_backend_webgpu_build_multi( webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, @@ -633,16 +499,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; + webgpu_encoded_op result = {}; std::vector bind_groups; + std::vector param_offsets; + result.num_kernels = pipelines.size(); for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); + const size_t param_size = params_list[i].size() * sizeof(uint32_t); + const size_t param_offset = param_arena.alloc_slot(param_size); std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); - entries.push_back( - { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); + entries.push_back({ .binding = params_binding_num, + .buffer = param_arena.buffer, + .offset = param_offset, + .size = param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -650,13 +521,12 @@ static webgpu_command ggml_backend_webgpu_build_multi( bind_group_desc.entries = entries.data(); bind_group_desc.label = pipelines[i].name.c_str(); bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - - params_bufs_list.push_back(params_bufs); + param_offsets.push_back(param_offset); } - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + for (size_t i = 0; i < param_offsets.size(); i++) { + ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), + params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); @@ -682,29 +552,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( #ifdef GGML_WEBGPU_GPU_PROFILE encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); -#endif - - wgpu::CommandBuffer commands = encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs_list; - result.num_kernels = pipelines.size(); -#ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; - // TODO: handle multiple pipeline names result.pipeline_name = pipelines.front().name; #endif return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - webgpu_pipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1) { + return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, { pipeline }, @@ -724,10 +586,28 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = - ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector commands = { command }; - std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); + + entries.push_back( + { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = ctx->memset_pipeline.name.c_str(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline.pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups(wg_x, 1, 1); + pass.End(); + + wgpu::CommandBuffer command = encoder.Finish(); + std::vector commands = { command }; + ctx->queue.Submit(commands.size(), commands.data()); } /** End WebGPU Actions */ @@ -840,7 +720,10 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -878,10 +761,14 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -940,10 +827,13 @@ static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -995,13 +885,14 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1056,13 +947,14 @@ static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1112,17 +1004,18 @@ static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1197,13 +1090,14 @@ static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1266,7 +1160,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1277,10 +1171,11 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1332,13 +1227,14 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1477,16 +1373,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +#ifndef __EMSCRIPTEN__ +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1575,9 +1473,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = - std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1656,9 +1553,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, if (use_blk) { GGML_ASSERT(has_mask); - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; @@ -1729,8 +1626,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); } if (use_blk) { - split_entries.push_back( - { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = blk_buf, + .offset = blk_entries[1].offset, + .size = blk_size_bytes }); } split_entries.push_back( { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); @@ -1799,14 +1698,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } +#endif // __EMSCRIPTEN__ -static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -1881,13 +1784,14 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1983,13 +1887,14 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2039,10 +1944,13 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2081,10 +1989,13 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2124,14 +2035,16 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(src)); } -static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2228,10 +2141,14 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2290,10 +2207,13 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2341,14 +2261,15 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2424,10 +2345,14 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(dst)); } -static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2449,10 +2374,13 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2543,7 +2471,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr workgroups_list.push_back({ wg_x_init, wg_y_init }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } @@ -2605,11 +2533,14 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, - workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, + entries_list, workgroups_list); } -static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2634,10 +2565,13 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2666,11 +2600,13 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2693,18 +2629,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, src0, node); + return ggml_webgpu_cpy(ctx, encoder, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, src0, src1, node); + return ggml_webgpu_set(ctx, encoder, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, src0, src1, node); + return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, src0, src1, node); + return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2712,22 +2648,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, src0, src1, node); + return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, src0, src1, node); + return ggml_webgpu_concat(ctx, encoder, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, src0, node); + return ggml_webgpu_repeat(ctx, encoder, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, src0, node); + return ggml_webgpu_row_norm(ctx, encoder, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, src0, src1, node); + return ggml_webgpu_glu(ctx, encoder, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, src0, node); + return ggml_webgpu_scale(ctx, encoder, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2738,26 +2674,27 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, src0, node); + return ggml_webgpu_unary_op(ctx, encoder, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); + return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], + node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, src0, node); + return ggml_webgpu_pad(ctx, encoder, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, src0, node); + return ggml_webgpu_argmax(ctx, encoder, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, src0, node); + return ggml_webgpu_argsort(ctx, encoder, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, src0, node); + return ggml_webgpu_cumsum(ctx, encoder, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, src0, node); + return ggml_webgpu_sum_rows(ctx, encoder, src0, node); default: return std::nullopt; } @@ -2771,30 +2708,42 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector subs; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; + std::vector commands; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; + wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - // Process events and check for completed submissions - ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); + num_batched_kernels = 0; + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); + batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } if (!commands.empty()) { - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); } @@ -2805,6 +2754,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.GetSize()); wgpu::CommandBuffer set_rows_commands = encoder.Finish(); ctx->global_ctx->queue.Submit(1, &set_rows_commands); + } + + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + + if (contains_set_rows) { ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, ctx->set_rows_host_error_buf.GetSize()); const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); @@ -2814,7 +2768,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.Unmap(); } - ggml_backend_webgpu_wait(ctx->global_ctx, subs); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); +#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3063,18 +3019,16 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = - (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (V->type == K->type); + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && + kv_vec_type_supported && (V->type == K->type); if (use_vec) { const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = - (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; if (!kv_direct) { bytes_per_kv += std::max(Q->ne[0], V->ne[0]); @@ -3084,10 +3038,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = - ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; if (kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { @@ -3097,30 +3050,30 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; } nwg = std::min(nwg, vec_nwg_cap); - const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t align = + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; if (nwg > 1u) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( + const size_t tmp_size_bytes = ROUNDUP_POW2( (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; } if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = - (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); res += blk_size_bytes + align; } @@ -3195,11 +3148,11 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = ctx->capabilities.memset_bytes_per_thread; - ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { @@ -3331,9 +3284,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); - ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -3357,9 +3310,8 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); + webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");