From aa2d278a11124bc0edcd103a1307ffde12985572 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 10 Mar 2026 09:14:27 -0700 Subject: [PATCH] ggml webgpu: faster normal quant and some k-quant matrix operations, better shader parameter handling (#20173) * K quant speedup (#20) * Basic JIT compilation for mul_mat, get_rows, and scale (#17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * no gibberish, all k quants added, merged * vec memory fix * q6_k matching metal on my machine, tests passing * Set tile size for q6_k separately * Separate out fast shaders --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> * Move towards writeBuffer for params * Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups * Remove extra file * Formatting --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 96 ++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 459 ++++++------ .../wgsl-shaders/mul_mat_decls.tmpl | 673 +++++++++++++++++- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 1 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 290 +++++++- 5 files changed, 1250 insertions(+), 269 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 17c5e0fb51..3c38b1a230 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -42,11 +42,20 @@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 // Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 + // Must be multiple of 4 to work with vectorized paths, and must divide // mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 + +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 + +// Requires 32 threads per output (wg_size/outputs_per_wg == 32) +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 +// Requires at least two (and multiple of 2) k-quant blocks per tile +#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key { bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && + src_overlap == other.src_overlap; } }; @@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "mul_mat_vec"; + // src0 type (matrix row) + switch (context.src0->type) { + case 
GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + + // For fast path we always dequantize from f16 inside the shader + defines.push_back("SRC0_INNER_TYPE=f16"); + break; + } + } + // src1 type (vector) switch (context.src1->type) { case GGML_TYPE_F32: @@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib { GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); } - // src0 type (matrix row) - switch (context.src0->type) { - case GGML_TYPE_F32: - defines.push_back("SRC0_INNER_TYPE=f32"); - defines.push_back("MUL_ACC_FLOAT"); - break; - case GGML_TYPE_F16: - defines.push_back("SRC0_INNER_TYPE=f16"); - defines.push_back("MUL_ACC_FLOAT"); - break; - default: - { - // Quantized types: use helpers but accumulate in f16 - const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); - std::string src0_name = src0_traits->type_name; - std::string type_upper = src0_name; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - defines.push_back("BYTE_HELPERS"); - defines.push_back("MUL_ACC_" + type_upper); - - // For fast path we always dequantize from f16 inside the shader - defines.push_back("SRC0_INNER_TYPE=f16"); - break; - } - } - // VEC/SCALAR controls defines.push_back(key.vectorized ? 
"VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; - uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type >= GGML_TYPE_Q2_K) { + tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); @@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, .src_overlap = context.src_overlap, }; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b2ef2d5901..ccc34cb153 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -20,12 +19,18 @@ #include #include #include -#include +#ifdef GGML_WEBGPU_GPU_PROFILE +# include +#endif +#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) +# include +#endif #include #include #include #include #include +#include #include #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) @@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. 
enough for two timestamps #endif /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 48u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u +#define WEBGPU_NUM_PARAM_BUFS 96u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { - std::vector free; + std::vector free; // The pool must be synchronized because // 1. The memset pool is shared globally by every ggml buffer, @@ -138,7 +137,6 @@ struct webgpu_buf_pool { size_t cur_pool_size; size_t max_pool_size; wgpu::Device device; - wgpu::BufferUsage host_buf_usage; wgpu::BufferUsage dev_buf_usage; size_t buf_size; bool should_grow; @@ -147,53 +145,47 @@ struct webgpu_buf_pool { int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage, bool should_grow = false, size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->host_buf_usage = host_buf_usage; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); + free.push_back(dev_buf); } } - webgpu_pool_bufs alloc_bufs() { + wgpu::Buffer alloc_bufs() { std::unique_lock lock(mutex); if (!free.empty()) { - webgpu_pool_bufs bufs = free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - if (!(host_buf && dev_buf)) { + if (!dev_buf) { GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); } - return webgpu_pool_bufs{ host_buf, dev_buf }; + return dev_buf; } cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = 
free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } - void free_bufs(std::vector bufs) { + void free_bufs(std::vector bufs) { std::lock_guard lock(mutex); free.insert(free.end(), bufs.begin(), bufs.end()); cv.notify_all(); @@ -201,12 +193,9 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); - for (auto & bufs : free) { - if (bufs.host_buf) { - bufs.host_buf.Destroy(); - } - if (bufs.dev_buf) { - bufs.dev_buf.Destroy(); + for (auto & buf : free) { + if (buf) { + buf.Destroy(); } } free.clear(); @@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; - std::optional set_rows_error_bufs; + uint32_t num_kernels; + wgpu::CommandBuffer commands; + std::vector params_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -358,6 +346,13 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; +struct webgpu_submission { + wgpu::FutureWaitInfo submit_done; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -366,7 +361,8 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; std::map> cpy_pipelines; // src_type, dst_type @@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** End WebGPU object initializations */ /** WebGPU Actions */ -static void erase_completed(std::vector & futures) { + +static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { + switch (status) { + case wgpu::WaitStatus::Success: + return true; + case wgpu::WaitStatus::TimedOut: + if (allow_timeout) { + return false; + } + GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); + return false; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + return false; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + return false; + } +} + +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { futures.erase(std::remove_if(futures.begin(), futures.end(), [](const wgpu::FutureWaitInfo & info) { return info.completed; }), futures.end()); } -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & futures, - bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. +static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, + std::vector & futures, + bool block) { if (futures.empty()) { return; } + uint64_t timeout_ms = block ? 
UINT64_MAX : 0; - while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX); - if (waitStatus == wgpu::WaitStatus::Error) { - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + if (block) { + while (!futures.empty()) { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { + ggml_backend_webgpu_erase_completed_futures(futures); + } } - if (futures[0].completed) { - futures.erase(futures.begin()); + } else { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { + ggml_backend_webgpu_erase_completed_futures(futures); + } + } +} +#endif + +// Wait for the queue to finish processing all submitted work +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & subs, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. + if (subs.empty()) { + return; + } + while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); +#endif + subs.erase(subs.begin()); } } - if (futures.empty()) { + if (subs.empty()) { return; } if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. - erase_completed(futures); - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - break; + for (auto & sub : subs) { + while (!sub.submit_done.completed) { + auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); + ggml_backend_webgpu_handle_wait_status(waitStatus); } +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); +#endif } + subs.clear(); } else { - // Poll once and return - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. - erase_completed(futures); - break; - case wgpu::WaitStatus::TimedOut: - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - break; + // Poll each submit future once and remove completed submissions. 
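+        // A submission is retired only when its submit future has completed and, when GPU
+        // profiling is enabled, all of its timestamp readback futures have resolved as well.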
+ for (auto sub = subs.begin(); sub != subs.end();) { + auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); + ggml_backend_webgpu_handle_wait_status(waitStatus, true); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); + if (sub->submit_done.completed && sub->profile_futures.empty()) { +#else + if (sub->submit_done.completed) { +#endif + sub = subs.erase(sub); + } else { + ++sub; + } } } } @@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static std::vector ggml_backend_webgpu_submit( - webgpu_global_context ctx, - std::vector commands, - webgpu_buf_pool & param_buf_pool, - webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { +static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, + std::vector & commands, + webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; - std::vector params_bufs; - std::vector set_rows_error_bufs; + std::vector params_bufs; + webgpu_submission submission; #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ -569,14 +601,9 @@ static std::vector ggml_backend_webgpu_submit( for (const auto & command : commands) { command_buffers.push_back(command.commands); params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - if (command.set_rows_error_bufs) { - set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); - } } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - std::vector futures; - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { @@ -586,27 +613,7 @@ static std::vector ggml_backend_webgpu_submit( // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); - futures.push_back({ p_f }); - - for (const auto & bufs : set_rows_error_bufs) { - wgpu::Future f = bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - if (set_rows_error_buf_pool) { - set_rows_error_buf_pool->free_bufs({ bufs }); - } - } - }); - futures.push_back({ f }); - } + submission.submit_done = { p_f }; #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { @@ -623,14 +630,14 @@ static std::vector ggml_backend_webgpu_submit( // WebGPU timestamps are in ns; convert to ms double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; ctx->shader_gpu_time_ms[label] += elapsed_ms; - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); } + // We can't unmap in here due to WebGPU reentrancy limitations. 
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - futures.push_back({ f }); + submission.profile_futures.push_back({ f }); } #endif - return futures; + return submission; } static webgpu_command ggml_backend_webgpu_build_multi( @@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list, - const std::optional & set_rows_error_bufs = std::nullopt) { + const std::vector> & workgroups_list) { GGML_ASSERT(pipelines.size() == params_list.size()); GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; - std::vector bind_groups; + std::vector params_bufs_list; + std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, - params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t j = 0; j < params_list[i].size(); j++) { - _params[j] = params_list[i][j]; - } - params_bufs.host_buf.Unmap(); + wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); + entries.push_back( + { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( } wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (const auto & params_bufs : params_bufs_list) { - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); - } - - // If there are SET_ROWS operations in this submission, copy their error - // buffers to the host. 
- if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); + for (size_t i = 0; i < params_bufs_list.size(); i++) { + ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE @@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( webgpu_command result = {}; result.commands = commands; result.params_bufs = params_bufs_list; - result.set_rows_error_bufs = set_rows_error_bufs; result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; @@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - uint32_t wg_y = 1, - std::optional set_rows_error_bufs = std::nullopt) { + uint32_t wg_y = 1) { return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline }, - { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); + { std::move(params) }, { std::move(bind_group_entries) }, + { { wg_x, wg_y } }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool); - ggml_backend_webgpu_wait(ctx, futures); + std::vector commands = { command }; + std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "\nggml_webgpu: gpu breakdown:\n"; for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? 
(kv.second / total_gpu * 100.0) : 0.0; - std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) + << pct << "%)\n"; } #endif @@ -978,14 +968,6 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); - std::optional error_bufs = std::nullopt; - if (decisions->i64_idx) { - error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); - if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - error_bufs->host_buf.Unmap(); - } - } - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -1018,8 +1000,10 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; if (decisions->i64_idx) { - entries.push_back( - { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); + entries.push_back({ .binding = 3, + .buffer = ctx->set_rows_dev_error_buf, + .offset = 0, + .size = ctx->set_rows_dev_error_buf.GetSize() }); } uint32_t threads; @@ -1029,8 +1013,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, - error_bufs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, use_fast = (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: + // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q6_K: use_fast = true; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat + use_fast = !is_vec; + break; default: break; } @@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations - uint32_t wg_m, wg_n; + uint32_t wg_m; + uint32_t wg_n; if (decisions->use_subgroup_matrix) { uint32_t wg_m_sg_tile = decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; @@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else { 
// legacy - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); @@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; std::vector params = { @@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], dim, - (uint32_t)src0->ne[dim] + (uint32_t) src0->ne[dim] }; std::vector entries = { - { - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) - }, - { - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) - }, - { - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) - } + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2172,19 +2169,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: - return 
ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_CLAMP: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_FILL: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQR: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQRT: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SIN: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_COS: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: @@ -2192,7 +2182,6 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_ARGMAX: return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: - return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k return ggml_webgpu_argsort(ctx, src0, node); @@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector futures; - uint32_t num_batched_kernels = 0; + std::vector commands; + std::vector subs; + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + contains_set_rows = true; + } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - std::vector compute_futures = ggml_backend_webgpu_submit( - ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.insert(futures.end(), compute_futures.begin(), compute_futures.end()); + num_batched_kernels = 0; + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); + ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); commands.clear(); } } if (!commands.empty()) { - auto new_futures = - ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.insert(futures.end(), new_futures.begin(), new_futures.end()); + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + commands.clear(); } - ggml_backend_webgpu_wait(ctx->global_ctx, futures); + // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. 
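+    // All SET_ROWS dispatches share the single device-side error buffer, so one readback at the
+    // end of the graph is enough to detect any index that exceeded the supported 2^32 range.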
+ if (contains_set_rows) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer set_rows_commands = encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &set_rows_commands); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); + } + + ggml_backend_webgpu_wait(ctx->global_ctx, subs); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); - webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 5c1074ebc1..de60ebbcf2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -11,7 +11,7 @@ fn store_shmem(val: vec4, idx: u32) { shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#endif +#endif // VEC #ifdef SCALAR #define VEC_SIZE 1 @@ -23,7 +23,7 @@ fn store_shmem(val: vec4, idx: u32) { fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#endif +#endif // SCALAR #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { @@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } -#endif +#endif // INIT_SRC0_SHMEM_FLOAT #ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { @@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } -#endif +#endif // INIT_SRC1_SHMEM_FLOAT #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; @@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } -#endif +#endif // INIT_SRC0_SHMEM_Q4_0 + +#ifdef INIT_SRC0_SHMEM_Q4_1 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. 
Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_lo = f16(q_byte & 0xF) * d + m; + let q_hi = f16((q_byte >> 4) & 0xF) * d + m; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q4_1 + +#ifdef INIT_SRC0_SHMEM_Q5_0 +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
+// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let qh0 = src0[scale_idx + 1u]; + let qh1 = src0[scale_idx + 2u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q5_0 + +#ifdef INIT_SRC0_SHMEM_Q5_1 +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
+// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + let qh0 = src0[scale_idx + 2u]; + let qh1 = src0[scale_idx + 3u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + + let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q5_1 + +#ifdef INIT_SRC0_SHMEM_Q8_0 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q8_0 + +#ifdef INIT_SRC0_SHMEM_Q8_1 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d + m; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q8_1 + +#ifdef INIT_SRC0_SHMEM_Q2_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 42u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + // Use standard thread layout instead of lane/row_group + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + 
continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 40u]; + let dmin = src0[scale_idx + 41u]; + + // Decode the element at position k_in_block + let block_of_32 = k_in_block / 32u; + let pos_in_32 = k_in_block % 32u; + + let q_b_idx = (block_of_32 / 4u) * 32u; + let shift = (block_of_32 % 4u) * 2u; + let k = (pos_in_32 / 16u) * 16u; + let l = pos_in_32 % 16u; + + let is = k_in_block / 16u; + + let sc_0 = src0[scale_idx + 2u * (is / 4u)]; + let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; + let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc = get_byte(sc_packed, is % 4u); + + let dl = d * f16(sc & 0xFu); + let ml = dmin * f16(sc >> 4u); + + let q_idx = q_b_idx + k + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif // INIT_SRC0_SHMEM_Q2_K + +#ifdef INIT_SRC0_SHMEM_Q3_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 55u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 54u]; + + // Load and unpack scales + let kmask1: u32 = 0x03030303u; + let kmask2: u32 = 0x0f0f0f0fu; + + var scale_vals: array; + for (var i: u32 = 0u; i < 4u; i++) { + let scale_0 = src0[scale_idx + 48u + (2u*i)]; + let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); + scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); + + // Load hmask and qs arrays + var hmask_vals: array; + for (var i: u32 = 0u; i < 8u; i++) { + let hmask_0 = src0[scale_idx + (2u*i)]; + let hmask_1 = src0[scale_idx + (2u*i) + 1u]; + hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + } + + var qs_vals: array; + for (var i: u32 = 0u; i < 16u; i++) { + let qs_0 = src0[scale_idx + 16u + (2u*i)]; + let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; + qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + } + + let half = k_in_block / 128u; // 0 or 1 + let pos_in_half = k_in_block % 128u; // 0-127 + let shift_group = pos_in_half / 32u; // 0-3 + let pos_in_32 = pos_in_half % 32u; // 0-31 + let k_group = pos_in_32 / 16u; // 0 or 1 + let l = pos_in_32 % 16u; // 0-15 + + let q_b_idx = half * 32u; // 0 or 32 + let shift = shift_group * 2u; // 0, 2, 4, 6 + let k = k_group * 16u; // 0 or 16 + let is = k_in_block / 16u; // 0-15 + + // m increments every 32 elements across entire 256 
element block + let m_shift = k_in_block / 32u; // 0-7 + let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 + + let sc = get_byte(scale_vals[is / 4u], is % 4u); + let dl = d * (f16(sc) - 32.0); + + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + + let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); + + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = (f16(qs_val) - f16(hm)) * dl; + shmem[elem_idx] = q_val; + } +} + +#endif // INIT_SRC0_SHMEM_Q3_K + +#ifdef INIT_SRC0_SHMEM_Q4_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 72u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // Map k_in_block to loop structure: + // Outer loop over 64-element groups (alternating q_b_idx) + // Inner loop over 2 shifts per group + let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 0xFu; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif // INIT_SRC0_SHMEM_Q4_K + +#ifdef INIT_SRC0_SHMEM_Q5_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 88u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let 
block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // The original loop processes elements in groups of 64 + // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] + // But u increments EVERY 32 elements (after each l loop) + let group_of_64 = k_in_block / 64u; // 0-3 + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) + let u_shift = k_in_block / 32u; // 0-7 + let u: u32 = 1u << u_shift; + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + + let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; + let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; + let qh_packed = bitcast(vec2(qh_0, qh_1)); + + let qh_byte = get_byte(qh_packed, l % 4u); + + let qs_val = (q_byte >> shift) & 0xFu; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + + let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; + shmem[elem_idx] = q_val; + } +} + +#endif // INIT_SRC0_SHMEM_Q5_K + +#ifdef INIT_SRC0_SHMEM_Q6_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 105u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let half = k_in_block / 128u; + let pos_in_half = k_in_block % 128u; + let quarter = pos_in_half / 32u; + let l = pos_in_half % 32u; + + let ql_b_idx = half * 64u; + let qh_b_idx = half * 32u; + let sc_b_idx = half * 8u; + + // Load only ql13 word needed + let ql13_flat = ql_b_idx + l; + let ql13_word = ql13_flat / 4u; + let ql13 = bitcast(vec2( + src0[scale_idx + 2u * ql13_word], + src0[scale_idx + 2u * ql13_word + 1u] + )); + 
+#ifdef INIT_SRC0_SHMEM_Q6_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let half = k_in_block / 128u;
+        let pos_in_half = k_in_block % 128u;
+        let quarter = pos_in_half / 32u;
+        let l = pos_in_half % 32u;
+
+        let ql_b_idx = half * 64u;
+        let qh_b_idx = half * 32u;
+        let sc_b_idx = half * 8u;
+
+        // Load only the ql13 word needed
+        let ql13_flat = ql_b_idx + l;
+        let ql13_word = ql13_flat / 4u;
+        let ql13 = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 2u * ql13_word],
+            src0[scale_idx + 2u * ql13_word + 1u]
+        ));
+        let ql13_b = get_byte(ql13, ql13_flat % 4u);
+
+        // Load only the ql24 word needed
+        let ql24_flat = ql_b_idx + l + 32u;
+        let ql24_word = ql24_flat / 4u;
+        let ql24 = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 2u * ql24_word],
+            src0[scale_idx + 2u * ql24_word + 1u]
+        ));
+        let ql24_b = get_byte(ql24, ql24_flat % 4u);
+
+        // Load only the qh word needed
+        let qh_flat = qh_b_idx + l;
+        let qh_word = qh_flat / 4u;
+        let qh = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 64u + 2u * qh_word],
+            src0[scale_idx + 64u + 2u * qh_word + 1u]
+        ));
+        let qh_b = get_byte(qh, qh_flat % 4u);
+
+        let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
+        let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
+        let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
+        let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
+
+        // Load only the scale word needed
+        let is = l / 16u;
+        let sc_idx = sc_b_idx + is + quarter * 2u;
+        let sc_word = sc_idx / 4u;
+        let sc = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 96u + 2u * sc_word],
+            src0[scale_idx + 96u + 2u * sc_word + 1u]
+        ));
+        let sc_val = get_byte_i32(sc, sc_idx % 4u);
+
+        let d = src0[scale_idx + 104u];
+
+        var q_val: f16;
+        if (quarter == 0u) {
+            q_val = q1;
+        } else if (quarter == 1u) {
+            q_val = q2;
+        } else if (quarter == 2u) {
+            q_val = q3;
+        } else {
+            q_val = q4;
+        }
+
+        shmem[elem_idx] = d * f16(sc_val) * q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q6_K
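The constants above come from viewing one Q6_K super-block as 105 f16 values: 128 bytes of low quant bits, 64 bytes of high bits, 16 signed scale bytes and one f16 super-block scale. That is why the shader indexes qh at f16 offset 64, the scales at 96 and d at 104. A layout sketch, assuming the usual ggml block_q6_K packing (the struct name is illustrative):

    #include <cstdint>

    // One 256-weight Q6_K super-block viewed in f16-sized units, matching
    // F16_PER_BLOCK = 105 and the offsets 64, 96 and 104 used in the shader.
    struct q6_k_block_f16_view {
        uint16_t ql[64];     // f16 offsets   0..63  : 128 bytes of low 4-bit quants
        uint16_t qh[32];     // f16 offsets  64..95  :  64 bytes of high 2-bit quants
        uint16_t scales[8];  // f16 offsets  96..103 :  16 signed 8-bit sub-block scales
        uint16_t d;          // f16 offset  104      : super-block scale (a real f16)
    };
    static_assert(sizeof(q6_k_block_f16_view) == 105 * sizeof(uint16_t), "105 f16 per Q6_K block");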
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index 761e3017c1..b1da421a69 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
 const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
 const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
 const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+
 var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
 
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
index f9ea95e07b..94f4bae11f 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -1,4 +1,3 @@
-
 enable f16;
 
 #include "common_decls.tmpl"
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
 }
 #endif
 
+#ifdef MUL_ACC_Q4_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u;              // number of weights per thread
+const F16_PER_BLOCK = 10u;
+const WEIGHTS_PER_F16 = 4u;  // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = f32(src0[scale_idx + 1u]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+                let q_lo = f32(q_byte & 0xF) * d + m;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q5_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u;              // number of weights per thread
+const F16_PER_BLOCK = 11u;
+const WEIGHTS_PER_F16 = 4u;  // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let qh0 = src0[scale_idx + 1u];
+        let qh1 = src0[scale_idx + 2u];
+        let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q5_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u;              // number of weights per thread
+const F16_PER_BLOCK = 12u;
+const WEIGHTS_PER_F16 = 4u;  // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+        let qh0 = src0[scale_idx + 2u];
+        let qh1 = src0[scale_idx + 3u];
+        let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
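In both 5-bit paths above, the 32-bit qh word carries the fifth bit of each of the 32 weights in a block: bit idx for the low-nibble half (weights 0-15) and bit idx + 16 for the high-nibble half (weights 16-31), and the shifts by idx and idx + 12 move that bit to position 4 of the reconstructed quant. A short C++ sketch of the same bit manipulation (helper names are illustrative, not patch code):

    #include <cstdint>

    // idx is the weight position within the low/high half, 0..15.
    static uint32_t fifth_bit_low(uint32_t qh, uint32_t idx) {
        return ((qh >> idx) << 4) & 0x10;        // bit idx        -> bit 4
    }
    static uint32_t fifth_bit_high(uint32_t qh, uint32_t idx) {
        return (qh >> (idx + 12)) & 0x10;        // bit (idx + 16) -> bit 4
    }
    // In the shader, idx corresponds to j_adjusted * 4 + k.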
+#ifdef MUL_ACC_Q8_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u;              // number of weights per thread
+const F16_PER_BLOCK = 17u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 holds two consecutive weights: offsets [block_offset * 2, block_offset * 2 + 1]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d;
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u;              // number of weights per thread
+const F16_PER_BLOCK = 18u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 holds two consecutive weights: offsets [block_offset * 2, block_offset * 2 + 1]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d + f32(m);
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
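The 8-bit paths reuse the tile walk shared by all the legacy-quant variants here: each thread takes NQ = 16 consecutive weights per step and strides by THREADS_PER_OUTPUT * NQ. With this patch's legacy-quant defaults (a 256-thread workgroup producing 64 outputs over a 256-element tile), THREADS_PER_OUTPUT is 4, so each thread visits every fourth 16-weight chunk. A small sketch of that partitioning (constants taken from those defaults, not from the shader source):

    #include <cstdio>

    int main() {
        const unsigned BLOCK_SIZE = 32, NQ = 16, THREADS_PER_OUTPUT = 4, TILE_K = 256;
        for (unsigned tig = 0; tig < THREADS_PER_OUTPUT; tig++) {
            for (unsigned i = tig * NQ; i < TILE_K; i += THREADS_PER_OUTPUT * NQ) {
                printf("thread %u: block %u, weights %u..%u\n", tig, i / BLOCK_SIZE, i, i + NQ - 1);
            }
        }
        return 0;
    }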
+#ifdef MUL_ACC_Q6_K
+
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
+    let aligned = byte_offset & ~3u;
+    let idx = bbase + aligned / 2u;
+    return bitcast<u32>(vec2<f16>(src0[idx], src0[idx + 1u]));
+}
+
+fn byte_of(v: u32, b: u32) -> u32 {
+    return (v >> (b * 8u)) & 0xFFu;
+}
+
+fn sbyte_of(v: u32, b: u32) -> i32 {
+    let raw = i32((v >> (b * 8u)) & 0xFFu);
+    return select(raw, raw - 256, raw >= 128);
+}
+
+fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    let tid = tig / 2u;
+    let ix = tig % 2u;
+    let ip = tid / 8u;
+    let il = tid % 8u;
+    let l0 = 4u * il;
+    let is = 8u * ip + l0 / 16u;
+
+    let y_offset = 128u * ip + l0;
+    let q_offset_l = 64u * ip + l0;
+    let q_offset_h = 32u * ip + l0;
+
+    let nb = tile_size / BLOCK_SIZE;
+    let k_block_start = k_outer / BLOCK_SIZE;
+
+    // Aligned scale byte position (is can be odd)
+    let sc_base_byte = 192u + (is & ~3u);
+    let sc_byte_pos = is & 3u;
+
+    var local_sum = 0.0;
+
+    for (var i = ix; i < nb; i += 2u) {
+        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+
+        let d_raw = load_u32_at(bbase, 208u);
+        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+
+        let ql1_u32 = load_u32_at(bbase, q_offset_l);
+        let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
+        let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
+        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+
+        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+        var sums = vec4(0.0, 0.0, 0.0, 0.0);
+
+        for (var l = 0u; l < 4u; l++) {
+            let y_base = i * BLOCK_SIZE + y_offset + l;
+            let yl0 = f32(shared_vector[y_base]);
+            let yl1 = f32(shared_vector[y_base + 32u]);
+            let yl2 = f32(shared_vector[y_base + 64u]);
+            let yl3 = f32(shared_vector[y_base + 96u]);
+
+            let q1b = byte_of(ql1_u32, l);
+            let q2b = byte_of(ql2_u32, l);
+            let qhb = byte_of(qh_u32, l);
+
+            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+            let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u))) - 32);
+            let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+            sums[0] += yl0 * dq0;
+            sums[1] += yl1 * dq1;
+            sums[2] += yl2 * dq2;
+            sums[3] += yl3 * dq3;
+        }
+
+        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+                          sums[2] * f32(sc4) + sums[3] * f32(sc6));
+    }
+
+    return local_sum;
+}
+#endif
+
 struct MulMatParams {
     offset_src0: u32,
     offset_src1: u32,
@@ -191,4 +478,3 @@ fn main(
         dst[dst_idx / VEC_SIZE] = store_val(group_base);
     }
 }
-
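The sbyte_of helper in the Q6_K path reads sub-block scales, which are signed 8-bit values packed four to a u32, and re-signs them by hand since the extracted bytes arrive zero-extended. The equivalent host-side arithmetic, as a sanity-check sketch (not part of the patch):

    #include <cassert>
    #include <cstdint>

    // Extract byte b (0..3) of v as a signed 8-bit value, mirroring sbyte_of above.
    static int32_t sbyte_of(uint32_t v, uint32_t b) {
        int32_t raw = (v >> (b * 8)) & 0xFF;
        return raw >= 128 ? raw - 256 : raw;   // undo the zero-extension
    }

    int main() {
        assert(sbyte_of(0x80FF7F01u, 0) ==    1);
        assert(sbyte_of(0x80FF7F01u, 1) ==  127);
        assert(sbyte_of(0x80FF7F01u, 2) ==   -1);
        assert(sbyte_of(0x80FF7F01u, 3) == -128);
        return 0;
    }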