From e4fed9d08de1d33ab51748880d38b62b9968dc2e Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Fri, 10 Apr 2026 13:52:01 -0400 Subject: [PATCH] ggml-webgpu: address quantization precision and backend lifecycle management (#21521) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS struct for NaN f16 canonicalization * Fix numerical precision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unnecessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidental removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops --------- Co-authored-by: Jeremy J. 
Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 55 ++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 37 +++- .../wgsl-shaders/common_decls.tmpl | 139 +++---------- .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 189 +++++++++++------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 161 ++++++++------- .../wgsl-shaders/mul_mat_decls.tmpl | 78 ++++---- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 46 ++--- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 8 +- 8 files changed, 383 insertions(+), 330 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c10157766d..3de6258c74 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + switch (key.src_type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. 
+ defines.push_back("SRC_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); @@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib { variant += "_"; variant += type_str; - defines.push_back(std::string("SRC_TYPE=") + type_str); defines.push_back("DST_TYPE=f32"); if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || @@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib { break; default: { - // quantized types std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - defines.push_back(std::string("SRC0_TYPE=") + src0_name); + switch (context.src0->type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. 
+ defines.push_back("SRC0_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index edfc657917..3b894a9b9c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* End Constants */ +static inline wgpu::CallbackMode ggml_webgpu_callback_mode() { +#ifdef __EMSCRIPTEN__ + return wgpu::CallbackMode::AllowProcessEvents; +#else + return wgpu::CallbackMode::AllowSpontaneous; +#endif +} + // This is a "fake" base pointer, since WebGPU buffers do not have pointers to // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, + ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, std::string callback_message; const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( - buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { 
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); + if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, + ctx->debug_host_buf.GetSize())) { + GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n"); + return; + } const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); @@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & auto ts_bufs = command.timestamp_query_bufs; wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(), [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); @@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, + &options, ggml_webgpu_callback_mode(), [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); @@ -3491,8 +3503,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( - 
wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + ggml_webgpu_callback_mode(), + [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; } @@ -3525,7 +3537,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + &dev_desc, ggml_webgpu_callback_mode(), [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { if (status != wgpu::RequestDeviceStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); @@ -4046,6 +4058,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 0; + // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend + // registry. Recreating it on repeated registry lookups can invalidate + // adapter/device references that are still held by the backend/device layer. 
+ if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + return ® + } + wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); @@ -4063,11 +4082,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); ctx.webgpu_global_ctx->instance = std::move(inst); + // Probe for adapter support wgpu::Adapter adapter; if (ctx.webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - // probe for adapter support ctx.webgpu_global_ctx->instance.WaitAny( ctx.webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index feb0bca3f8..0d3501c34a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -9,35 +9,43 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { #endif #ifdef U32_DEQUANT_HELPERS -fn load_src0_u16_at(byte_offset: u32) -> u32 { - let word = src0[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; +fn load_u16_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + return (word >> shift) & 0xFFFF; } -fn load_src0_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = src0[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = src0[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); +fn load_u32_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4; + let shift = (byte_offset & 0x3) * 8; + let lo = buf[word_idx]; + let hi = buf[word_idx + 1]; + 
let shifted = (lo >> shift) | (hi << (32 - shift)); + return select(shifted, lo, shift == 0); } -fn load_src0_f16_at(byte_offset: u32) -> f16 { - let packed = unpack2x16float(load_src0_u16_at(byte_offset)); +fn load_f16_at( + buf: ptr, read_write>, + byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at(buf, byte_offset)); return f16(packed[0]); } + +fn load_f16_as_f32_at( + buf: ptr, read_write>, + byte_offset: u32) -> f32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + let d_bits = (word >> shift) & 0xFFFF; + return unpack2x16float(d_bits)[0]; +} #endif -#ifdef Q4_0_T -struct q4_0 { - d: f16, - qs: array -}; -#endif + #ifdef Q4_1_T struct q4_1 { @@ -47,13 +55,6 @@ struct q4_1 { }; #endif -#ifdef Q5_0_T -struct q5_0 { - d: f16, - qh: array, - qs: array -}; -#endif #ifdef Q5_1_T struct q5_1 { @@ -64,12 +65,6 @@ struct q5_1 { }; #endif -#ifdef Q8_0_T -struct q8_0 { - d: f16, - qs: array -}; -#endif #ifdef Q8_1_T struct q8_1 { @@ -88,14 +83,6 @@ struct q2_K { }; #endif -#ifdef Q3_K_T -struct q3_K { - hmask: array, - qs: array, - scales: array, - d: f16 -}; -#endif #if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { @@ -132,64 +119,6 @@ struct q5_K { }; #endif -#ifdef Q6_K_T -struct q6_K { - ql: array, - qh: array, - scales: array, - d: f16 -}; -#endif - -#ifdef IQ2_XXS_T -struct iq2_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ2_XS_T -struct iq2_xs { - d: f16, - qs: array, - scales: array -}; -#endif - -#ifdef IQ2_S_T -struct iq2_s { - d: f16, - qs: array, - qh: array, - scales: array -}; -#endif - -#ifdef IQ3_XXS_T -struct iq3_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ3_S_T -struct iq3_s { - d: f16, - qs: array, - qh: array, - signs: array, - scales: array -}; -#endif - -#ifdef IQ1_S_T -struct iq1_s { - d: f16, - qs: array, - qh: array -}; -#endif - #ifdef IQ1_M_T struct iq1_m { qs: array, @@ -198,17 +127,9 @@ struct iq1_m { }; #endif -#ifdef IQ4_NL_T 
-struct iq4_nl { - d: f16, - qs: array, -}; -#endif - #ifdef IQ4_XS_T struct iq4_xs { - d: f16, - scales_h: f16, + d_scales_h: u32, scales_l: u32, qs: array }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index d9eb6a3567..3c8b84c9ac 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; + dst[dst_offset + 16u] = q_hi; } } } @@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + let qh_packed = load_u32_at(&src, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = 
bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; dst[dst_offset + 16] = q_hi; @@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 8u; j++) { + let q_byte_offset = block_byte_base + 2u + j * 4u; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; + let dst_offset = dst_base + offset * 32u + j * 4u + k; dst[dst_offset] = q_val; } } @@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 108); + + // Bytes 
96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; var is: u32 = 0; var m: u32 = 1; + // 2 halves of the block (128 elements each) for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { // 4 groups (each group has 2 blocks of 16 elements) @@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sc = get_byte(scale_vals[is / 4], is % 4); is++; let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { + + for (var l: u32 = 0; l < 16; l++) { let q_idx = q_b_idx + k + l; let hm_idx = k + l; let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); 
let qs_val = (q_byte >> shift) & 3; dst[dst_i] = (f32(qs_val) - hm) * dl; @@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + for (var i: u32 = 0; i < 16u; i++) { + qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + 
let aux0 = load_u32_at(&src, aux0_offset); + let aux1 = load_u32_at(&src, aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif + + #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); let db = array( @@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - 
bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src, block_byte_base + 78); + for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); let db = array( @@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qh_vals = array( - 
bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src, block_byte_base + 106); + for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); let db = array( @@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = 
(get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 5b9f5b3622..fdabaf09b2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { - let q_packed = 
bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); + let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let qh_packed = load_u32_at(&src0, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); + let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; let src1_offset = src1_idx_base + 
offset * 32 + j * 4 + k; @@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src0, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src0, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, 
src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at(&src0, aux0_offset); + let aux1 = load_u32_at(&src0, 
aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); @@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } - var qh_vals = array( - 
bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src0, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src0, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src0, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 78); + var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src0, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + 
offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src0, block_byte_base + 106); + var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); 
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index ea91c13468..374137ff8e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + 
global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 
2u); @@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset 
= block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 80u); - let dmin = load_src0_f16_at(block_byte_base + 82u); + let d = load_f16_at(&src0, block_byte_base + 80u); + let dmin = load_f16_at(&src0, block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 108u); + let d = load_f16_at(&src0, block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); + 
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let 
dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); + let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat); let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat); let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat); let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); @@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let 
sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx); let sc_val = get_byte_i32(sc, 0u); - let d = load_src0_f16_at(block_byte_base + 208u); + let d = load_f16_at(&src0, block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6525f23bdf..6f6bcaf794 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = f32(load_src0_f16_at(block_byte_base + 2u)); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = f32(load_f16_at(&src0, block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 
2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -221,11 +221,11 @@ fn 
mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = ix; i < nb; i += 2u) { let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d = f32(load_src0_f16_at(bbase + 208u)); + let d = f32(load_f16_at(&src0, bbase + 208u)); - let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); - let ql2_u32 = 
load_src0_u32_at(bbase + q_offset_l + 32u); - let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); - let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); - let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); + let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l); + let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u); + let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h); + let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte); + let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 21beb9bb94..8c334817cc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); #endif #ifdef EXP - let res = exp(src[params.offset_src + src_idx]); + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32)); #endif #ifdef LOG let res = TYPE(log(f32(src[params.offset_src + src_idx]))); @@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); #endif #ifdef EXPM1 - let res = exp(src[params.offset_src + src_idx]) - 1.0; + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32) - 1.0); #endif #ifdef FLOOR let res = floor(src[params.offset_src + src_idx]); @@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; #endif #ifdef SQRT - let res = sqrt(src[params.offset_src + src_idx]); + let res = TYPE(sqrt(f32(src[params.offset_src + src_idx]))); #endif #ifdef SIN let res_f32 = sin(f32(src[params.offset_src + 
src_idx]));