From 5a23695d5a78d2000a674fe54e5c4ecdff0a7089 Mon Sep 17 00:00:00 2001
From: Reese Levine
Date: Tue, 14 Apr 2026 03:46:41 -0700
Subject: [PATCH] ggml-webgpu: Update register tiling matmul to use f32
 accumulation (#21644)

* Update register tiling matmul to use f32 accumulation

* Fix profiling code

* Fix register tiling matmul for Chrome; I'm blaming Dawn

* Update batch tuning value for iOS

* Compile fix

* Fix use of new load function
---
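
Notes:

Why f32 accumulation: f16 carries only an 11-bit significand, so the ULP
at 2048 is already 2; once a running dot product grows past that, adding
a product smaller than 1 is rounded away entirely, and over the long K
dimensions of LLM matmuls the f16 accumulator drifts or stalls. The patch
keeps tile data in f16 and widens each operand just before the
multiply-add. A minimal standalone sketch of the pattern (the
dot_f16_acc_f32 helper below is illustrative only, not code from this
patch):

    enable f16;

    // Dot product over f16 inputs with an f32 accumulator. Widening
    // each operand before the multiply-add keeps both the product and
    // the running sum in f32, matching what mul_mat_reg_tile.wgsl now
    // does with its acc tile.
    fn dot_f16_acc_f32(a: array<f16, 8>, b: array<f16, 8>) -> f32 {
        var acc: f32 = 0.0;
        for (var i: u32 = 0u; i < 8u; i++) {
            acc += f32(a[i]) * f32(b[i]);
        }
        return acc;
    }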
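
On the mul_mat_decls.tmpl change: the K-quant path used to preload the 12
packed scale bytes into a local scale_vals: array<u32, 3> and index that
array dynamically; per the commit bullets this misbehaved on Chrome
(Dawn), so the scales are now loaded directly at fixed offsets. For
reference, those 12 bytes hold eight 6-bit (scale, min) pairs, the same
layout the CPU backend unpacks in get_scale_min_k4. A consolidated sketch
of the unpacking (get_scale_min is a hypothetical wrapper; get_byte,
load_u32_at, and the src0 binding are the template's, with their assumed
semantics of byte extraction from a u32 and a u32 load at a byte offset):

    // Pairs 0..3: low 6 bits of bytes 0..3 (scales) and 4..7 (mins).
    // Pairs 4..7: low/high nibbles of bytes 8..11, with the two high
    // bits of each value spilled into the top bits of bytes 0..7.
    fn get_scale_min(is: u32, scale_base: u32) -> vec2<u32> {
        if (is < 4u) {
            let sc = get_byte(load_u32_at(&src0, scale_base), is % 4u) & 63u;
            let mn = get_byte(load_u32_at(&src0, scale_base + 4u), is % 4u) & 63u;
            return vec2<u32>(sc, mn);
        }
        let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8u), (is + 4u) % 4u);
        let sc_hi     = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
        let min_hi    = get_byte(load_u32_at(&src0, scale_base + 4u), is % 4u);
        let sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
        let mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
        return vec2<u32>(sc, mn);
    }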

 ggml/src/ggml-webgpu/ggml-webgpu.cpp          | 51 +++++++------------
 .../wgsl-shaders/mul_mat_decls.tmpl           | 35 +++++--------
 .../wgsl-shaders/mul_mat_reg_tile.wgsl        | 12 ++---
 .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl |  3 ++
 4 files changed, 40 insertions(+), 61 deletions(-)

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 634201bc64..8d0e109365 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
 
 /* Constants */
 
-#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u
+#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u
 #define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
 #define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
 #define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
@@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
 
 /* End Constants */
 
-static inline wgpu::CallbackMode ggml_webgpu_callback_mode() {
-#ifdef __EMSCRIPTEN__
-    return wgpu::CallbackMode::AllowProcessEvents;
-#else
-    return wgpu::CallbackMode::AllowSpontaneous;
-#endif
-}
-
 // This is a "fake" base pointer, since WebGPU buffers do not have pointers to
 // their locations.
 static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
@@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status,
 }
 
 #ifdef __EMSCRIPTEN__
-// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures.
 EM_JS(int, ggml_webgpu_is_ios_browser, (), {
     const ua = navigator.userAgent;
     return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0;
 });
 #endif
 
-static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) {
+// TODO: these next two functions may want tuning across different platforms and workloads.
+static uint32_t ggml_backend_webgpu_get_max_inflight_batches() {
 #ifdef __EMSCRIPTEN__
+    // iOS has very strict limits on the number of in-flight GPU commands,
+    // so we need to throttle to avoid failures.
     if (ggml_webgpu_is_ios_browser()) {
        return 1;
    }
-#else
-    GGML_UNUSED(info);
 #endif
-
    return UINT32_MAX;
 }
 
-static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) {
-#ifdef __EMSCRIPTEN__
-    if (ggml_webgpu_is_ios_browser()) {
-        return 16;
-    }
-#else
-    GGML_UNUSED(info);
-#endif
-
+static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() {
    return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE;
 }
 
@@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) {
 
     const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
         ctx->queue.OnSubmittedWorkDone(
-            ggml_webgpu_callback_mode(),
+            wgpu::CallbackMode::AllowSpontaneous,
             [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                 callback_status = status;
                 callback_message = std::string(message);
@@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
     std::string callback_message;
 
     const wgpu::WaitStatus wait_status = ctx->instance.WaitAny(
-        buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(),
+        buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
                         [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                             callback_status = status;
                             callback_message = std::string(message);
@@ -542,15 +525,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 #endif
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
-static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx,
-                                                        const std::vector<webgpu_command> & commands,
-                                                        std::vector<wgpu::FutureWaitInfo> & futures) {
+static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context &             ctx,
+                                                        const std::vector<webgpu_command> & commands,
+                                                        std::vector<wgpu::FutureWaitInfo> & futures) {
     for (const auto & command : commands) {
         auto label = command.pipeline_name;
         auto ts_bufs = command.timestamp_query_bufs;
 
         wgpu::Future f = ts_bufs.host_buf.MapAsync(
-            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(),
+            wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
             [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                 if (status != wgpu::MapAsyncStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
@@ -3428,7 +3411,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
 
     ctx->webgpu_global_ctx->instance.WaitAny(
         ctx->webgpu_global_ctx->instance.RequestAdapter(
-            &options, ggml_webgpu_callback_mode(),
+            &options, wgpu::CallbackMode::AllowSpontaneous,
             [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
                 if (status != wgpu::RequestAdapterStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
@@ -3449,8 +3432,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     }
 #endif
     ctx->webgpu_global_ctx->adapter.GetInfo(&info);
-    ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info);
-    ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info);
+    ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size();
+    ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches();
     wgpu::SupportedFeatures features;
     ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
     // we require f16 support
@@ -3501,7 +3484,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
     dev_desc.requiredFeatures = required_features.data();
     dev_desc.requiredFeatureCount = required_features.size();
     dev_desc.SetDeviceLostCallback(
-        ggml_webgpu_callback_mode(),
+        wgpu::CallbackMode::AllowSpontaneous,
         [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
             if (reason == wgpu::DeviceLostReason::Destroyed) {
                 return;
@@ -3535,7 +3518,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
 
     ctx->webgpu_global_ctx->instance.WaitAny(
         ctx->webgpu_global_ctx->adapter.RequestDevice(
-            &dev_desc, ggml_webgpu_callback_mode(),
+            &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
            [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
                 if (status != wgpu::RequestDeviceStatus::Success) {
                     GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
index 374137ff8e..56a76a6e6c 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     let d = load_f16_at(&src0, block_byte_base);
     let dmin = load_f16_at(&src0, block_byte_base + 2u);
 
-    // Load packed scales
-    var scale_vals: array<u32, 3>;
-    for (var i: u32 = 0u; i < 3u; i++) {
-        scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
-    }
-
     // Map k_in_block to loop structure:
     // Outer loop over 64-element groups (alternating q_b_idx)
     // Inner loop over 2 shifts per group
@@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         var sc: u32;
         var mn: u32;
 
+        let scale_base = block_byte_base + 4u;
+
         if (is < 4u) {
-            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
-            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
+            let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
             sc = sc_byte & 63u;
             mn = min_byte & 63u;
         } else {
-            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
-            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
-            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+            let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
+            let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
+            let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
 
             sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
             mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
     let d = load_f16_at(&src0, block_byte_base);
     let dmin = load_f16_at(&src0, block_byte_base + 2u);
 
-    // Load packed scales
-    var scale_vals: array<u32, 3>;
-    for (var i: u32 = 0u; i < 3u; i++) {
-        scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i);
-    }
 
     // The original loop processes elements in groups of 64
     // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
@@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         var sc: u32;
        var mn: u32;
 
+        let scale_base = block_byte_base + 4u;
+
        if (is < 4u) {
-            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
-            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
+            let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
             sc = sc_byte & 63u;
             mn = min_byte & 63u;
         } else {
-            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
-            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
-            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+            let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
+            let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
+            let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
 
             sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
             mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index b1da421a69..ee37e6d249 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -4,14 +4,14 @@ enable f16;
 
 #include "mul_mat_decls.tmpl"
 
 #ifdef VEC
-fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
-    return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
+fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
+    return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]);
 }
 #endif
 
 #ifdef SCALAR
-fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
-    return f32(acc[tm][tn]);
+fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
+    return acc[tm][tn];
 }
 #endif
@@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
     let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
     let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
 
-    var acc: array<array<f16, TILE_N>, TILE_M>;
+    var acc: array<array<f32, TILE_N>, TILE_M>;
 
     for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
 
@@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
             let src1_idx = src1_n * TILE_K + k_inner;
             let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
             for (var tm = 0u; tm < TILE_M; tm++) {
-                acc[tm][tn] += src0_tile[tm] * src1_val;
+                acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val);
             }
         }
     }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
index 9f9ef279f2..4151ce430b 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
@@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix;
 #include "common_decls.tmpl"
 #include "mul_mat_decls.tmpl"
 
+// TODO: this shader path does not work with some models (e.g. qwen2.5) on Metal devices: f16 accumulation causes NaNs.
+// See https://github.com/ggml-org/llama.cpp/issues/21602
+
 #ifdef VEC
 fn store_dst(shmem_idx: u32, dst_idx: u32) {
     dst[dst_idx] = vec4(